-
Notifications
You must be signed in to change notification settings - Fork 1
feature: Implement saving operations to / reading from the on-disk storage #149
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 11 commits
a882291
07ff04d
0e9eead
d7cbf7f
dcc80aa
49f6dcb
e1b6b37
89d7fc3
bb1cd6b
f3313ea
82a85ba
8bee433
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,14 @@ | |
|
||
from __future__ import annotations | ||
|
||
from pathlib import Path | ||
from types import TracebackType | ||
|
||
from neptune_scale.sync.operations_repository import ( | ||
Metadata, | ||
OperationsRepository, | ||
) | ||
|
||
__all__ = ["Run"] | ||
|
||
import atexit | ||
|
@@ -22,7 +30,6 @@ | |
|
||
from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import ForkPoint | ||
from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import Run as CreateRun | ||
from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation | ||
|
||
from neptune_scale.api.attribute import AttributeStore | ||
from neptune_scale.api.metrics import Metrics | ||
|
@@ -35,6 +42,7 @@ | |
) | ||
from neptune_scale.exceptions import ( | ||
NeptuneApiTokenNotProvided, | ||
NeptuneConflictingDataInLocalStorage, | ||
NeptuneProjectNotProvided, | ||
) | ||
from neptune_scale.net.serialization import ( | ||
|
@@ -46,20 +54,15 @@ | |
ErrorsQueue, | ||
) | ||
from neptune_scale.sync.lag_tracking import LagTracker | ||
from neptune_scale.sync.operations_queue import OperationsQueue | ||
from neptune_scale.sync.parameters import ( | ||
MAX_EXPERIMENT_NAME_LENGTH, | ||
MAX_QUEUE_SIZE, | ||
MAX_RUN_ID_LENGTH, | ||
MINIMAL_WAIT_FOR_ACK_SLEEP_TIME, | ||
MINIMAL_WAIT_FOR_PUT_SLEEP_TIME, | ||
STOP_MESSAGE_FREQUENCY, | ||
) | ||
from neptune_scale.sync.sequence_tracker import SequenceTracker | ||
from neptune_scale.sync.sync_process import SyncProcess | ||
from neptune_scale.util.abstract import ( | ||
Resource, | ||
WithResources, | ||
) | ||
from neptune_scale.util.envs import ( | ||
API_TOKEN_ENV_NAME, | ||
PROJECT_ENV_NAME, | ||
|
@@ -74,7 +77,7 @@ | |
logger = get_logger() | ||
|
||
|
||
class Run(WithResources, AbstractContextManager): | ||
class Run(AbstractContextManager): | ||
""" | ||
Representation of tracked metadata. | ||
""" | ||
|
@@ -91,7 +94,7 @@ def __init__( | |
creation_time: Optional[datetime] = None, | ||
fork_run_id: Optional[str] = None, | ||
fork_step: Optional[Union[int, float]] = None, | ||
max_queue_size: int = MAX_QUEUE_SIZE, | ||
max_queue_size: Optional[int] = None, | ||
async_lag_threshold: Optional[float] = None, | ||
on_async_lag_callback: Optional[Callable[[], None]] = None, | ||
on_queue_full_callback: Optional[Callable[[BaseException, Optional[float]], None]] = None, | ||
|
@@ -114,7 +117,7 @@ def __init__( | |
creation_time: Custom creation time of the run. | ||
fork_run_id: If forking from an existing run, ID of the run to fork from. | ||
fork_step: If forking from an existing run, step number to fork from. | ||
max_queue_size: Maximum number of operations in a queue. | ||
max_queue_size: Deprecated. | ||
async_lag_threshold: Threshold for the duration between the queueing and synchronization of an operation | ||
(in seconds). If the duration exceeds the threshold, the callback function is triggered. | ||
on_async_lag_callback: Callback function triggered when the duration between the queueing and synchronization | ||
|
@@ -134,7 +137,6 @@ def __init__( | |
verify_type("creation_time", creation_time, (datetime, type(None))) | ||
verify_type("fork_run_id", fork_run_id, (str, type(None))) | ||
verify_type("fork_step", fork_step, (int, float, type(None))) | ||
verify_type("max_queue_size", max_queue_size, int) | ||
verify_type("async_lag_threshold", async_lag_threshold, (int, float, type(None))) | ||
verify_type("on_async_lag_callback", on_async_lag_callback, (Callable, type(None))) | ||
verify_type("on_queue_full_callback", on_queue_full_callback, (Callable, type(None))) | ||
|
@@ -161,8 +163,8 @@ def __init__( | |
): | ||
raise ValueError("`on_async_lag_callback` must be used with `async_lag_threshold`.") | ||
|
||
if max_queue_size < 1: | ||
raise ValueError("`max_queue_size` must be greater than 0.") | ||
if max_queue_size is not None: | ||
logger.warning("`max_queue_size` is deprecated and will be removed in a future version.") | ||
|
||
project = project or os.environ.get(PROJECT_ENV_NAME) | ||
if project: | ||
|
@@ -198,12 +200,25 @@ def __init__( | |
self._run_id: str = run_id | ||
|
||
self._lock = threading.RLock() | ||
self._operations_queue: OperationsQueue = OperationsQueue( | ||
lock=self._lock, | ||
max_size=max_queue_size, | ||
) | ||
|
||
self._attr_store: AttributeStore = AttributeStore(self._project, self._run_id, self._operations_queue) | ||
operations_repository_path = _create_repository_path(self._project, self._run_id) | ||
self._operations_repo: OperationsRepository = OperationsRepository( | ||
db_path=operations_repository_path, | ||
) | ||
self._operations_repo.init_db() | ||
pitercl marked this conversation as resolved.
Show resolved
Hide resolved
|
||
existing_metadata = self._operations_repo.get_metadata() | ||
|
||
# Save metadata if it doesn't exist | ||
if existing_metadata is None: | ||
self._operations_repo.save_metadata(self._project, self._run_id, fork_run_id, fork_step) | ||
# Check for conflicts when not resuming | ||
elif not resume: | ||
self._check_for_run_conflicts(existing_metadata, fork_run_id, fork_step) | ||
|
||
self._sequence_tracker = SequenceTracker() | ||
pitercl marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self._attr_store: AttributeStore = AttributeStore( | ||
self._project, self._run_id, self._operations_repo, self._sequence_tracker | ||
) | ||
|
||
self._errors_queue: ErrorsQueue = ErrorsQueue() | ||
self._errors_monitor = ErrorsMonitor( | ||
|
@@ -222,21 +237,20 @@ def __init__( | |
self._sync_process = SyncProcess( | ||
project=self._project, | ||
family=self._run_id, | ||
operations_queue=self._operations_queue.queue, | ||
operations_repository_path=operations_repository_path, | ||
errors_queue=self._errors_queue, | ||
process_link=self._process_link, | ||
api_token=input_api_token, | ||
last_queued_seq=self._last_queued_seq, | ||
last_ack_seq=self._last_ack_seq, | ||
last_ack_timestamp=self._last_ack_timestamp, | ||
max_queue_size=max_queue_size, | ||
mode=mode, | ||
) | ||
self._lag_tracker: Optional[LagTracker] = None | ||
if async_lag_threshold is not None and on_async_lag_callback is not None: | ||
self._lag_tracker = LagTracker( | ||
errors_queue=self._errors_queue, | ||
operations_queue=self._operations_queue, | ||
sequence_tracker=self._sequence_tracker, | ||
last_ack_timestamp=self._last_ack_timestamp, | ||
async_lag_threshold=async_lag_threshold, | ||
on_async_lag_callback=on_async_lag_callback, | ||
|
@@ -251,6 +265,7 @@ def __init__( | |
self._exit_func: Optional[Callable[[], None]] = atexit.register(self._close) | ||
|
||
if not resume: | ||
# Create a new run | ||
self._create_run( | ||
creation_time=datetime.now() if creation_time is None else creation_time, | ||
experiment_name=experiment_name, | ||
|
@@ -259,28 +274,28 @@ def __init__( | |
) | ||
self.wait_for_processing(verbose=False) | ||
|
||
def _check_for_run_conflicts( | ||
pitercl marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self, existing_metadata: Metadata, fork_run_id: Optional[str], fork_step: Optional[float] | ||
) -> None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this method should become something like
|
||
if existing_metadata.project != self._project or existing_metadata.run_id != self._run_id: | ||
# should never happen because we use project and run_id to create the repository path | ||
raise NeptuneConflictingDataInLocalStorage() | ||
if existing_metadata.parent_run_id == fork_run_id and existing_metadata.fork_step == fork_step: | ||
logger.warning( | ||
"Run already exists in local storage with the same parent run and fork point. Resuming the run." | ||
) | ||
return | ||
else: | ||
# Same run_id but different fork points | ||
raise NeptuneConflictingDataInLocalStorage() | ||
|
||
def _on_child_link_closed(self, _: ProcessLink) -> None: | ||
with self._lock: | ||
if not self._is_closing: | ||
logger.error("Child process closed unexpectedly.") | ||
self._is_closing = True | ||
self.terminate() | ||
|
||
@property | ||
def resources(self) -> tuple[Resource, ...]: | ||
if self._lag_tracker is not None: | ||
return ( | ||
self._errors_queue, | ||
self._operations_queue, | ||
self._lag_tracker, | ||
self._errors_monitor, | ||
) | ||
return ( | ||
self._errors_queue, | ||
self._operations_queue, | ||
self._errors_monitor, | ||
) | ||
|
||
def _close(self, *, wait: bool = True) -> None: | ||
with self._lock: | ||
if self._is_closing: | ||
|
@@ -310,7 +325,8 @@ def _close(self, *, wait: bool = True) -> None: | |
if threading.current_thread() != self._errors_monitor: | ||
self._errors_monitor.join() | ||
|
||
super().close() | ||
self._operations_repo.close() | ||
pitercl marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self._errors_queue.close() | ||
|
||
def terminate(self) -> None: | ||
""" | ||
|
@@ -363,6 +379,14 @@ def close(self) -> None: | |
self._exit_func = None | ||
self._close(wait=True) | ||
|
||
def __exit__( | ||
self, | ||
exc_type: Optional[type[BaseException]], | ||
exc_value: Optional[BaseException], | ||
traceback: Optional[TracebackType], | ||
) -> None: | ||
self.close() | ||
|
||
def _create_run( | ||
self, | ||
creation_time: datetime, | ||
|
@@ -376,17 +400,15 @@ def _create_run( | |
parent_project=self._project, parent_run_id=fork_run_id, step=make_step(number=fork_step) | ||
) | ||
|
||
operation = RunOperation( | ||
project=self._project, | ||
run_id=self._run_id, | ||
create=CreateRun( | ||
family=self._run_id, | ||
fork_point=fork_point, | ||
experiment_id=experiment_name, | ||
creation_time=None if creation_time is None else datetime_to_proto(creation_time), | ||
), | ||
create_run = CreateRun( | ||
family=self._run_id, | ||
fork_point=fork_point, | ||
experiment_id=experiment_name, | ||
creation_time=None if creation_time is None else datetime_to_proto(creation_time), | ||
) | ||
self._operations_queue.enqueue(operation=operation) | ||
|
||
sequence = self._operations_repo.save_create_run(create_run) | ||
self._sequence_tracker.update_sequence_id(sequence) | ||
|
||
def log_metrics( | ||
self, | ||
|
@@ -604,20 +626,20 @@ def _wait( | |
|
||
# Handle the case where we get notified on `wait_seq` before we actually wait. | ||
# Otherwise, we would unnecessarily block, waiting on a notify_all() that never happens. | ||
if wait_seq.value >= self._operations_queue.last_sequence_id: | ||
if wait_seq.value >= self._sequence_tracker.last_sequence_id: | ||
break | ||
|
||
with wait_seq: | ||
wait_seq.wait(timeout=wait_time) | ||
value = wait_seq.value | ||
|
||
last_queued_sequence_id = self._operations_queue.last_sequence_id | ||
last_queued_sequence_id = self._sequence_tracker.last_sequence_id | ||
|
||
if value == -1: | ||
if self._operations_queue.last_sequence_id != -1: | ||
if self._sequence_tracker.last_sequence_id != -1: | ||
last_print_timestamp = print_message( | ||
f"Waiting. No operations were {phrase} yet. Operations to sync: %s", | ||
self._operations_queue.last_sequence_id + 1, | ||
self._sequence_tracker.last_sequence_id + 1, | ||
last_print=last_print_timestamp, | ||
verbose=verbose, | ||
) | ||
|
@@ -692,3 +714,8 @@ def print_message(msg: str, *args: Any, last_print: Optional[float] = None, verb | |
return current_time | ||
|
||
return last_print | ||
|
||
|
||
def _create_repository_path(project: str, run_id: str) -> Path: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. TODO: ideally we need to come up with something unambiguous, but usable (url encoding is not usable). Need to think some more. |
||
sanitized_project = project.replace("/", "_") | ||
return Path(os.getcwd()) / ".neptune" / f"{sanitized_project}_{run_id}.sqlite3" |
Uh oh!
There was an error while loading. Please reload this page.