
feature: Implement saving operations to / reading from the on-disk storage #149


Merged: 12 commits (Mar 7, 2025)
Changes from 2 commits
1 change: 1 addition & 0 deletions .github/workflows/tests-e2e.yml
@@ -33,6 +33,7 @@ jobs:
pip install -r dev_requirements.txt

- name: Run tests
timeout-minutes: 30
env:
NEPTUNE_API_TOKEN: ${{ secrets.E2E_API_TOKEN }}
NEPTUNE_E2E_PROJECT: ${{ secrets.E2E_PROJECT }}
17 changes: 11 additions & 6 deletions src/neptune_scale/api/attribute.py
@@ -15,7 +15,8 @@

from neptune_scale.api.metrics import Metrics
from neptune_scale.sync.metadata_splitter import MetadataSplitter
from neptune_scale.sync.operations_queue import OperationsQueue
from neptune_scale.sync.operations_repository import OperationsRepository
from neptune_scale.sync.sequence_tracker import SequenceTracker

__all__ = ("Attribute", "AttributeStore")

@@ -65,11 +66,14 @@ class AttributeStore:
end consuming the queue (which would be SyncProcess).
"""

def __init__(self, project: str, run_id: str, operations_queue: OperationsQueue) -> None:
def __init__(
self, project: str, run_id: str, operations_repo: OperationsRepository, sequence_tracker: SequenceTracker
) -> None:
self._project = project
self._run_id = run_id
self._operations_queue = operations_queue
self._operations_repo = operations_repo
self._attributes: dict[str, Attribute] = {}
self._sequence_tracker = sequence_tracker

def __getitem__(self, path: str) -> "Attribute":
path = cleanup_path(path)
@@ -108,9 +112,10 @@ def log(
remove_tags=tags_remove,
)

for operation, metadata_size in splitter:
key = metrics.batch_key() if metrics is not None else None
self._operations_queue.enqueue(operation=operation, size=metadata_size, key=key)
operations = list(splitter)
sequence_id = self._operations_repo.save_update_run_snapshots(operations)

self._sequence_tracker.update_sequence_id(sequence_id)


class Attribute:
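`SequenceTracker` itself is not part of this diff. Judging from its call sites (`update_sequence_id` in `log()` above, plus `last_sequence_id` in `Run._wait` and the `LagTracker` wiring below), a minimal sketch might look like the following. This is a hypothetical reconstruction, not the PR's actual implementation:

```python
import threading


class SequenceTracker:
    """Thread-safe holder for the highest sequence id returned by the repository.

    Hypothetical sketch inferred from call sites in this PR; the real class
    lives in neptune_scale/sync/sequence_tracker.py and is not shown here.
    """

    def __init__(self) -> None:
        self._lock = threading.RLock()
        # -1 means "nothing written yet", matching the checks in Run._wait below.
        self._last_sequence_id = -1

    @property
    def last_sequence_id(self) -> int:
        with self._lock:
            return self._last_sequence_id

    def update_sequence_id(self, sequence_id: int) -> None:
        with self._lock:
            # Guard against out-of-order updates from concurrent writers.
            self._last_sequence_id = max(self._last_sequence_id, sequence_id)
```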
4 changes: 0 additions & 4 deletions src/neptune_scale/api/metrics.py
@@ -1,6 +1,5 @@
from __future__ import annotations

from collections.abc import Hashable
from dataclasses import dataclass
from typing import (
Optional,
@@ -40,6 +39,3 @@ def __post_init__(self) -> None:
self.preview_completion = None
if self.preview_completion is not None:
verify_value_between("preview_completion", self.preview_completion, 0.0, 1.0)

def batch_key(self) -> Hashable:
return (self.step, self.preview, self.preview_completion)
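With `batch_key` removed, batching by key at the queue boundary goes away; operations are instead persisted through the new `OperationsRepository`, whose implementation is not in this excerpt. The stub below summarizes its surface as inferred from the call sites in this PR (the SQLite backing is implied by the `.sqlite3` path in `run.py` below); names and signatures are assumptions, not the PR's code:

```python
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional


@dataclass
class Metadata:
    # Field names inferred from Run._check_for_run_conflicts below.
    project: str
    run_id: str
    parent_run_id: Optional[str]
    fork_step: Optional[float]


class OperationsRepository:
    """Interface sketch only; signatures inferred from call sites in this PR."""

    def __init__(self, db_path: Path) -> None:
        self._db_path = db_path

    def init_db(self) -> None:
        """Create the backing SQLite schema if it does not exist yet."""
        raise NotImplementedError

    def get_metadata(self) -> Optional[Metadata]:
        """Return stored run metadata, or None on first use of this database."""
        raise NotImplementedError

    def save_metadata(
        self, project: str, run_id: str, parent_run_id: Optional[str], fork_step: Optional[float]
    ) -> None:
        """Persist run identity so later sessions can detect duplicates or conflicts."""
        raise NotImplementedError

    def save_update_run_snapshots(self, operations: list[Any]) -> int:
        """Append a batch of snapshot operations; return the last assigned sequence id."""
        raise NotImplementedError

    def save_create_run(self, create_run: Any) -> int:
        """Append a CreateRun operation; return its sequence id."""
        raise NotImplementedError

    def close(self) -> None:
        """Flush and close the underlying connection."""
        raise NotImplementedError
```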
126 changes: 75 additions & 51 deletions src/neptune_scale/api/run.py
@@ -4,6 +4,14 @@

from __future__ import annotations

from pathlib import Path
from types import TracebackType

from neptune_scale.sync.operations_repository import (
Metadata,
OperationsRepository,
)

__all__ = ["Run"]

import atexit
@@ -22,7 +30,6 @@

from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import ForkPoint
from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import Run as CreateRun
from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation

from neptune_scale.api.attribute import AttributeStore
from neptune_scale.api.metrics import Metrics
@@ -36,6 +43,8 @@
from neptune_scale.exceptions import (
NeptuneApiTokenNotProvided,
NeptuneProjectNotProvided,
NeptuneRunConflicting,
NeptuneRunDuplicate,
)
from neptune_scale.net.serialization import (
datetime_to_proto,
@@ -46,20 +55,15 @@
ErrorsQueue,
)
from neptune_scale.sync.lag_tracking import LagTracker
from neptune_scale.sync.operations_queue import OperationsQueue
from neptune_scale.sync.parameters import (
MAX_EXPERIMENT_NAME_LENGTH,
MAX_QUEUE_SIZE,
MAX_RUN_ID_LENGTH,
MINIMAL_WAIT_FOR_ACK_SLEEP_TIME,
MINIMAL_WAIT_FOR_PUT_SLEEP_TIME,
STOP_MESSAGE_FREQUENCY,
)
from neptune_scale.sync.sequence_tracker import SequenceTracker
from neptune_scale.sync.sync_process import SyncProcess
from neptune_scale.util.abstract import (
Resource,
WithResources,
)
from neptune_scale.util.envs import (
API_TOKEN_ENV_NAME,
PROJECT_ENV_NAME,
@@ -74,7 +78,7 @@
logger = get_logger()


class Run(WithResources, AbstractContextManager):
class Run(AbstractContextManager):
"""
Representation of tracked metadata.
"""
@@ -91,7 +95,7 @@ def __init__(
creation_time: Optional[datetime] = None,
fork_run_id: Optional[str] = None,
fork_step: Optional[Union[int, float]] = None,
max_queue_size: int = MAX_QUEUE_SIZE,
max_queue_size: Optional[int] = None,
async_lag_threshold: Optional[float] = None,
on_async_lag_callback: Optional[Callable[[], None]] = None,
on_queue_full_callback: Optional[Callable[[BaseException, Optional[float]], None]] = None,
@@ -114,7 +118,6 @@ def __init__(
creation_time: Custom creation time of the run.
fork_run_id: If forking from an existing run, ID of the run to fork from.
fork_step: If forking from an existing run, step number to fork from.
max_queue_size: Maximum number of operations in a queue.
async_lag_threshold: Threshold for the duration between the queueing and synchronization of an operation
(in seconds). If the duration exceeds the threshold, the callback function is triggered.
on_async_lag_callback: Callback function triggered when the duration between the queueing and synchronization
@@ -134,7 +137,6 @@
verify_type("creation_time", creation_time, (datetime, type(None)))
verify_type("fork_run_id", fork_run_id, (str, type(None)))
verify_type("fork_step", fork_step, (int, float, type(None)))
verify_type("max_queue_size", max_queue_size, int)
verify_type("async_lag_threshold", async_lag_threshold, (int, float, type(None)))
verify_type("on_async_lag_callback", on_async_lag_callback, (Callable, type(None)))
verify_type("on_queue_full_callback", on_queue_full_callback, (Callable, type(None)))
@@ -161,8 +163,8 @@ def __init__(
):
raise ValueError("`on_async_lag_callback` must be used with `async_lag_threshold`.")

if max_queue_size < 1:
raise ValueError("`max_queue_size` must be greater than 0.")
if max_queue_size is not None:
logger.warning("`max_queue_size` is deprecated and will be removed in a future version.")

project = project or os.environ.get(PROJECT_ENV_NAME)
if project:
@@ -198,12 +200,25 @@ def __init__(
self._run_id: str = run_id

self._lock = threading.RLock()
self._operations_queue: OperationsQueue = OperationsQueue(
lock=self._lock,
max_size=max_queue_size,
)

self._attr_store: AttributeStore = AttributeStore(self._project, self._run_id, self._operations_queue)
operations_repository_path = _create_repository_path(self._project, self._run_id)
self._operations_repo: OperationsRepository = OperationsRepository(
db_path=operations_repository_path,
)
self._operations_repo.init_db()
existing_metadata = self._operations_repo.get_metadata()

# Save metadata if it doesn't exist
if existing_metadata is None:
self._operations_repo.save_metadata(self._project, self._run_id, fork_run_id, fork_step)
# Check for conflicts when not resuming
elif not resume:
self._check_for_run_conflicts(existing_metadata, fork_run_id, fork_step)

self._sequence_tracker = SequenceTracker()
self._attr_store: AttributeStore = AttributeStore(
self._project, self._run_id, self._operations_repo, sequence_tracker=self._sequence_tracker
)

self._errors_queue: ErrorsQueue = ErrorsQueue()
self._errors_monitor = ErrorsMonitor(
@@ -222,21 +237,20 @@ def __init__(
self._sync_process = SyncProcess(
project=self._project,
family=self._run_id,
operations_queue=self._operations_queue.queue,
operations_repository_path=operations_repository_path,
errors_queue=self._errors_queue,
process_link=self._process_link,
api_token=input_api_token,
last_queued_seq=self._last_queued_seq,
last_ack_seq=self._last_ack_seq,
last_ack_timestamp=self._last_ack_timestamp,
max_queue_size=max_queue_size,
mode=mode,
)
self._lag_tracker: Optional[LagTracker] = None
if async_lag_threshold is not None and on_async_lag_callback is not None:
self._lag_tracker = LagTracker(
errors_queue=self._errors_queue,
operations_queue=self._operations_queue,
sequence_tracker=self._sequence_tracker,
last_ack_timestamp=self._last_ack_timestamp,
async_lag_threshold=async_lag_threshold,
on_async_lag_callback=on_async_lag_callback,
@@ -251,6 +265,7 @@ def __init__(
self._exit_func: Optional[Callable[[], None]] = atexit.register(self._close)

if not resume:
# Create a new run
self._create_run(
creation_time=datetime.now() if creation_time is None else creation_time,
experiment_name=experiment_name,
@@ -259,28 +274,25 @@ def __init__(
)
self.wait_for_processing(verbose=False)

def _check_for_run_conflicts(
self, existing_metadata: Metadata, fork_run_id: Optional[str], fork_step: Optional[float]
) -> None:
Review comment (Member):
I think this method should become something like _validate_existing_db() and within it there should be:

  1. Check if the version matches. If not -> NeptuneLocalStorageInUnsupportedVersion
  2. Check if all other metadata matches. If not -> NeptuneConflictingDataInLocalStorage; if it does -> warning.

if existing_metadata.project != self._project or existing_metadata.run_id != self._run_id:
# should never happen because we use project and run_id to create the repository path
raise NeptuneRunConflicting()
if existing_metadata.parent_run_id == fork_run_id and existing_metadata.fork_step == fork_step:
raise NeptuneRunDuplicate()
else:
# Same run_id but different fork points
raise NeptuneRunConflicting()
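A minimal sketch of the validation flow the reviewer proposes above, assuming `Metadata` gains a `version` field, a `DB_VERSION` constant exists, and `NeptuneLocalStorageInUnsupportedVersion` / `NeptuneConflictingDataInLocalStorage` are defined as exceptions; none of these appears in this diff:

```python
def _validate_existing_db(
    self, existing_metadata: Metadata, fork_run_id: Optional[str], fork_step: Optional[float]
) -> None:
    # 1. Version check: refuse databases written by an incompatible client.
    #    `version` and DB_VERSION are hypothetical here.
    if existing_metadata.version != DB_VERSION:
        raise NeptuneLocalStorageInUnsupportedVersion()

    # 2. Metadata check: any mismatch means the local store belongs to a different run.
    matches = (
        existing_metadata.project == self._project
        and existing_metadata.run_id == self._run_id
        and existing_metadata.parent_run_id == fork_run_id
        and existing_metadata.fork_step == fork_step
    )
    if not matches:
        raise NeptuneConflictingDataInLocalStorage()

    # Full match: the run is already recorded locally, so warn rather than fail.
    logger.warning("Run '%s' already exists in local storage; resuming with existing data.", self._run_id)
```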

def _on_child_link_closed(self, _: ProcessLink) -> None:
with self._lock:
if not self._is_closing:
logger.error("Child process closed unexpectedly.")
self._is_closing = True
self.terminate()

@property
def resources(self) -> tuple[Resource, ...]:
if self._lag_tracker is not None:
return (
self._errors_queue,
self._operations_queue,
self._lag_tracker,
self._errors_monitor,
)
return (
self._errors_queue,
self._operations_queue,
self._errors_monitor,
)

def _close(self, *, wait: bool = True) -> None:
with self._lock:
if self._is_closing:
@@ -310,7 +322,8 @@ def _close(self, *, wait: bool = True) -> None:
if threading.current_thread() != self._errors_monitor:
self._errors_monitor.join()

super().close()
self._operations_repo.close()
self._errors_queue.close()

def terminate(self) -> None:
"""
@@ -363,6 +376,14 @@ def close(self) -> None:
self._exit_func = None
self._close(wait=True)

def __exit__(
self,
exc_type: Optional[type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
self.close()
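With `WithResources` gone, `Run` now implements the context-manager exit itself, delegating to `close()`. Usage is unchanged; a short illustration with made-up project and run identifiers:

```python
with Run(project="my-workspace/my-project", run_id="example-run") as run:
    run.log_metrics(data={"loss": 0.25}, step=1)
# On exit, close() flushes pending operations, then closes the
# operations repository and the errors queue.
```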

def _create_run(
self,
creation_time: datetime,
@@ -376,17 +397,15 @@ def _create_run(
parent_project=self._project, parent_run_id=fork_run_id, step=make_step(number=fork_step)
)

operation = RunOperation(
project=self._project,
run_id=self._run_id,
create=CreateRun(
family=self._run_id,
fork_point=fork_point,
experiment_id=experiment_name,
creation_time=None if creation_time is None else datetime_to_proto(creation_time),
),
create_run = CreateRun(
family=self._run_id,
fork_point=fork_point,
experiment_id=experiment_name,
creation_time=None if creation_time is None else datetime_to_proto(creation_time),
)
self._operations_queue.enqueue(operation=operation)

sequence = self._operations_repo.save_create_run(create_run)
self._sequence_tracker.update_sequence_id(sequence)

def log_metrics(
self,
@@ -604,20 +623,20 @@ def _wait(

# Handle the case where we get notified on `wait_seq` before we actually wait.
# Otherwise, we would unnecessarily block, waiting on a notify_all() that never happens.
if wait_seq.value >= self._operations_queue.last_sequence_id:
if wait_seq.value >= self._sequence_tracker.last_sequence_id:
break

with wait_seq:
wait_seq.wait(timeout=wait_time)
value = wait_seq.value

last_queued_sequence_id = self._operations_queue.last_sequence_id
last_queued_sequence_id = self._sequence_tracker.last_sequence_id

if value == -1:
if self._operations_queue.last_sequence_id != -1:
if self._sequence_tracker.last_sequence_id != -1:
last_print_timestamp = print_message(
f"Waiting. No operations were {phrase} yet. Operations to sync: %s",
self._operations_queue.last_sequence_id + 1,
self._sequence_tracker.last_sequence_id + 1,
last_print=last_print_timestamp,
verbose=verbose,
)
@@ -692,3 +711,8 @@ def print_message(msg: str, *args: Any, last_print: Optional[float] = None, verb
return current_time

return last_print


def _create_repository_path(project: str, run_id: str) -> Path:
Review comment (Member):
TODO: ideally we need to come up with something unambiguous, but usable (url encoding is not usable). Need to think some more.

sanitized_project = project.replace("/", "_")
return Path(os.getcwd()) / f".neptune/{sanitized_project}_{run_id}.sqlite3"
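On the TODO above: replacing `/` with `_` is ambiguous, since project `a/b_c` with run `d` and project `a/b` with run `c_d` both map to `a_b_c_d`. One unambiguous but still readable option, offered only as a sketch, keeps the sanitized name as a human-readable prefix and appends a short digest of the raw identifiers:

```python
import hashlib


def _create_repository_path(project: str, run_id: str) -> Path:
    # Keep a readable (but lossy) prefix for humans browsing .neptune/ ...
    sanitized = f"{project.replace('/', '_')}_{run_id}"
    # ...and disambiguate with a short digest of the raw, un-sanitized pair,
    # using NUL as a separator that cannot occur in either identifier.
    digest = hashlib.sha256(f"{project}\0{run_id}".encode()).hexdigest()[:8]
    return Path(os.getcwd()) / ".neptune" / f"{sanitized}-{digest}.sqlite3"
```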
5 changes: 3 additions & 2 deletions src/neptune_scale/net/api_client.py
@@ -71,7 +71,6 @@
NeptuneUnableToAuthenticateError,
)
from neptune_scale.sync.parameters import REQUEST_TIMEOUT
from neptune_scale.util.abstract import Resource
from neptune_scale.util.envs import ALLOW_SELF_SIGNED_CERTIFICATE
from neptune_scale.util.logger import get_logger

@@ -122,13 +121,15 @@ def create_auth_api_client(
)


class ApiClient(Resource, abc.ABC):
class ApiClient(abc.ABC):
@abc.abstractmethod
def submit(self, operation: RunOperation, family: str) -> Response[SubmitResponse]: ...

@abc.abstractmethod
def check_batch(self, request_ids: list[str], project: str) -> Response[BulkRequestStatus]: ...

def close(self) -> None: ...


class HostedApiClient(ApiClient):
def __init__(self, api_token: str) -> None: