Sending operations asynchronously #17

Merged: 22 commits, Aug 21, 2024

1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -29,5 +29,6 @@ repos:
additional_dependencies:
- neptune-api==0.4.0
- more-itertools
- backoff
default_language_version:
python: python3
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -13,6 +13,8 @@ python = "^3.8"

neptune-api = "0.4.0"
more-itertools = "^10.0.0"
psutil = "^5.0.0"
backoff = "^2.0.0"

[tool.poetry]
name = "neptune-client-scale"
@@ -74,6 +76,8 @@ force_grid_wrap = 2

[tool.ruff]
line-length = 120
target-version = "py38"
ignore = ["UP006", "UP007"]

[tool.ruff.lint]
select = ["F", "UP"]
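Both new runtime dependencies support the asynchronous sync process: `psutil` presumably for process monitoring, and `backoff` for retrying transient network failures. As a minimal sketch of how `backoff` is typically applied (the function and exception type below are illustrative, not taken from this PR):

```python
import backoff


class TransientNetworkError(Exception):
    """Stand-in for whatever exception the client raises on a failed request."""


# Retries on TransientNetworkError with exponential backoff (full jitter by
# default), giving up after 60 seconds and re-raising the last error.
@backoff.on_exception(backoff.expo, TransientNetworkError, max_time=60)
def submit_with_retries(payload: bytes) -> None:
    ...  # perform the network call here
```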
175 changes: 135 additions & 40 deletions src/neptune_scale/__init__.py
@@ -6,24 +6,40 @@

__all__ = ["Run"]

import atexit
import multiprocessing
import os
import threading
import time
from contextlib import AbstractContextManager
from datetime import datetime
from typing import Callable
from multiprocessing.sharedctypes import Synchronized
from multiprocessing.synchronize import Condition as ConditionT
from typing import (
Callable,
Dict,
List,
Literal,
Optional,
Set,
Union,
)

from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import ForkPoint
from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import Run as CreateRun
from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation

from neptune_scale.api.api_client import ApiClient
from neptune_scale.core.components.abstract import (
Resource,
WithResources,
)
from neptune_scale.core.components.errors_monitor import ErrorsMonitor
from neptune_scale.core.components.errors_queue import ErrorsQueue
from neptune_scale.core.components.errors_tracking import (
ErrorsMonitor,
ErrorsQueue,
)
from neptune_scale.core.components.operations_queue import OperationsQueue
from neptune_scale.core.components.sync_process import SyncProcess
from neptune_scale.core.logger import logger
from neptune_scale.core.metadata_splitter import MetadataSplitter
from neptune_scale.core.serialization import (
datetime_to_proto,
@@ -44,6 +60,8 @@
MAX_FAMILY_LENGTH,
MAX_QUEUE_SIZE,
MAX_RUN_ID_LENGTH,
MINIMAL_WAIT_FOR_PUT_SLEEP_TIME,
STOP_MESSAGE_FREQUENCY,
)


@@ -57,15 +75,17 @@ def __init__(
*,
family: str,
run_id: str,
project: str | None = None,
api_token: str | None = None,
project: Optional[str] = None,
api_token: Optional[str] = None,
resume: bool = False,
as_experiment: str | None = None,
creation_time: datetime | None = None,
from_run_id: str | None = None,
from_step: int | float | None = None,
mode: Literal["async", "disabled"] = "async",
as_experiment: Optional[str] = None,
creation_time: Optional[datetime] = None,
from_run_id: Optional[str] = None,
from_step: Optional[Union[int, float]] = None,
max_queue_size: int = MAX_QUEUE_SIZE,
max_queue_size_exceeded_callback: Callable[[int, BaseException], None] | None = None,
max_queue_size_exceeded_callback: Optional[Callable[[BaseException], None]] = None,
on_network_error_callback: Optional[Callable[[BaseException], None]] = None,
) -> None:
"""
Initializes a run that logs the model-building metadata to Neptune.
@@ -79,15 +99,15 @@
api_token: Your Neptune API token. If not provided, the value of the `NEPTUNE_API_TOKEN` environment
variable is used.
resume: Whether to resume an existing run.
mode: Mode of operation. If set to "disabled", the run doesn't log any metadata.
as_experiment: If creating a run as an experiment, ID of an experiment to be associated with the run.
creation_time: Custom creation time of the run.
from_run_id: If forking from an existing run, ID of the run to fork from.
from_step: If forking from an existing run, step number to fork from.
max_queue_size: Maximum number of operations in a queue.
max_queue_size_exceeded_callback: Callback function triggered when a queue is full.
Accepts two arguments:
- Maximum size of the queue.
- Exception that made the queue full.
max_queue_size_exceeded_callback: Callback function triggered when the queue is full. The function should take the exception
that made the queue full as its argument.
on_network_error_callback: Callback function triggered when a network error occurs.
"""
verify_type("family", family, str)
verify_type("run_id", run_id, str)
@@ -143,13 +163,33 @@ def __init__(

self._lock = threading.RLock()
self._operations_queue: OperationsQueue = OperationsQueue(
lock=self._lock, max_size=max_queue_size, max_size_exceeded_callback=max_queue_size_exceeded_callback
lock=self._lock,
max_size=max_queue_size,
)
self._errors_queue: ErrorsQueue = ErrorsQueue()
self._errors_monitor = ErrorsMonitor(errors_queue=self._errors_queue)
self._backend: ApiClient = ApiClient(api_token=input_api_token)
self._errors_monitor = ErrorsMonitor(
errors_queue=self._errors_queue,
max_queue_size_exceeded_callback=max_queue_size_exceeded_callback,
on_network_error_callback=on_network_error_callback,
)
self._last_put_seq: Synchronized[int] = multiprocessing.Value("i", -1)
self._last_put_seq_wait: ConditionT = multiprocessing.Condition()
self._sync_process = SyncProcess(
family=self._family,
operations_queue=self._operations_queue.queue,
errors_queue=self._errors_queue,
api_token=input_api_token,
last_put_seq=self._last_put_seq,
last_put_seq_wait=self._last_put_seq_wait,
max_queue_size=max_queue_size,
mode=mode,
)

self._errors_monitor.start()
with self._lock:
self._sync_process.start()

self._exit_func: Optional[Callable[[], None]] = atexit.register(self._close)

if not resume:
self._create_run(
@@ -159,32 +199,44 @@
from_step=from_step,
)
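For reference, a minimal sketch of constructing a `Run` with the two new callbacks documented in the constructor docstring above (the family and run ID values are hypothetical):

```python
from neptune_scale import Run


def on_queue_full(exc: BaseException) -> None:
    print(f"operations queue full: {exc}")


def on_network_error(exc: BaseException) -> None:
    print(f"network error while syncing: {exc}")


run = Run(
    family="my-run-family",
    run_id="my-run-id",
    max_queue_size_exceeded_callback=on_queue_full,
    on_network_error_callback=on_network_error,
)
```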

def __enter__(self) -> Run:
return self

@property
def resources(self) -> tuple[Resource, ...]:
return (
self._errors_queue,
self._operations_queue,
self._backend,
self._errors_monitor,
self._errors_queue,
)

def _close(self) -> None:
# TODO: Change to wait for all operations to be processed
with self._lock:
if self._sync_process.is_alive():
self.wait_for_submission()
self._sync_process.terminate()
self._sync_process.join()

self._errors_monitor.interrupt()
self._errors_monitor.join()

super().close()

def close(self) -> None:
"""
Stops the connection to Neptune and synchronizes all data.
"""
super().close()
if self._exit_func is not None:
atexit.unregister(self._exit_func)
self._exit_func = None
self._close()
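Because the constructor registers an `atexit` hook and the class defines `__enter__`, a run can be closed explicitly, implicitly at interpreter exit, or via a `with` block (assuming `__exit__` is implemented by the `WithResources` base and calls `close()`); `close()` unregisters the hook so teardown runs only once. A usage sketch:

```python
from neptune_scale import Run

with Run(family="my-run-family", run_id="my-run-id") as run:
    run.log(step=0, metrics={"loss": 0.5})
# Leaving the block (or interpreter exit, via the atexit hook) closes the run:
# pending operations are awaited and the sync process is shut down exactly once.
```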

def _create_run(
self,
creation_time: datetime,
as_experiment: str | None,
from_run_id: str | None,
from_step: int | float | None,
as_experiment: Optional[str],
from_run_id: Optional[str],
from_step: Optional[Union[int, float]],
) -> None:
fork_point: ForkPoint | None = None
fork_point: Optional[ForkPoint] = None
if from_run_id is not None and from_step is not None:
fork_point = ForkPoint(
parent_project=self._project, parent_run_id=from_run_id, step=make_step(number=from_step)
@@ -200,18 +252,16 @@ def _create_run(
creation_time=None if creation_time is None else datetime_to_proto(creation_time),
),
)
self._backend.submit(operation=operation, family=self._family)
# TODO: Enqueue on the operations queue
# self._operations_queue.enqueue(operation=operation)
self._operations_queue.enqueue(operation=operation)

def log(
self,
step: float | int | None = None,
timestamp: datetime | None = None,
fields: dict[str, float | bool | int | str | datetime | list | set] | None = None,
metrics: dict[str, float] | None = None,
add_tags: dict[str, list[str] | set[str]] | None = None,
remove_tags: dict[str, list[str] | set[str]] | None = None,
step: Optional[Union[float, int]] = None,
timestamp: Optional[datetime] = None,
fields: Optional[Dict[str, Union[float, bool, int, str, datetime, list, set]]] = None,
metrics: Optional[Dict[str, float]] = None,
add_tags: Optional[Dict[str, Union[List[str], Set[str]]]] = None,
remove_tags: Optional[Dict[str, Union[List[str], Set[str]]]] = None,
) -> None:
"""
Logs the specified metadata to Neptune.
Expand Down Expand Up @@ -268,6 +318,51 @@ def log(
)

for operation in splitter:
self._backend.submit(operation=operation, family=self._family)
# TODO: Enqueue on the operations queue
# self._operations_queue.enqueue(operation=operation)
self._operations_queue.enqueue(operation=operation)
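A short sketch of calling `log()` with the argument types the signature above accepts (the field names and values are hypothetical):

```python
from datetime import datetime

run.log(
    step=5,
    timestamp=datetime.now(),
    metrics={"loss": 0.21, "accuracy": 0.93},
    fields={"stage": "train", "epoch": 5},
    add_tags={"tags": ["baseline", "v2"]},
)
```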

def wait_for_submission(self, timeout: Optional[float] = None) -> None:
"""
Waits until all metadata is submitted to Neptune, or until the optional `timeout` (in seconds) elapses.
"""
begin_time = time.time()
logger.info("Waiting for all operations to be processed")
if timeout is None:
logger.warning("No timeout specified. Waiting indefinitely")

with self._lock:
if not self._sync_process.is_alive():
logger.warning("Sync process is not running")
return # No need to wait if the sync process is not running

sleep_time_wait = (
min(MINIMAL_WAIT_FOR_PUT_SLEEP_TIME, timeout) if timeout is not None else MINIMAL_WAIT_FOR_PUT_SLEEP_TIME
)
last_queued_sequence_id = self._operations_queue.last_sequence_id
last_message_printed: Optional[float] = None
while True:
with self._last_put_seq_wait:
self._last_put_seq_wait.wait(timeout=sleep_time_wait)
value = self._last_put_seq.value
if value == -1:
if self._operations_queue.last_sequence_id != -1:
if last_message_printed is None or time.time() - last_message_printed > STOP_MESSAGE_FREQUENCY:
last_message_printed = time.time()
logger.info(
"Waiting. No operations processed yet. Operations to sync: %s",
self._operations_queue.last_sequence_id + 1,
)
else:
if last_message_printed is None or time.time() - last_message_printed > STOP_MESSAGE_FREQUENCY:
last_message_printed = time.time()
logger.info("Waiting. No operations processed yet")
else:
if last_message_printed is None or time.time() - last_message_printed > STOP_MESSAGE_FREQUENCY:
last_message_printed = time.time()
logger.info(
"Waiting for remaining %d operation(s) to be synced",
last_queued_sequence_id - value + 1,
)
if value >= last_queued_sequence_id or (timeout is not None and time.time() - begin_time > timeout):
break

logger.info("All operations processed")