feature: Implement saving operations to / reading from the on-disk storage #149
Merged
Changes from all commits (12 commits):
a882291 feat: persistent mode (PatrykGala)
07ff04d feat: rebase (PatrykGala)
0e9eead feat: fix tests - remove SyncRun (PatrykGala)
d7cbf7f feat: Review changes (PatrykGala)
dcc80aa feat: Review changes (PatrykGala)
49f6dcb feat: Review changes (PatrykGala)
e1b6b37 feat: Add tests for sync process (PatrykGala)
89d7fc3 feat: Revert elseif (PatrykGala)
bb1cd6b feat: ValueError -> RuntimeError (PatrykGala)
f3313ea feat: Add exceptions (PatrykGala)
82a85ba feat: Add exceptions - precommit (PatrykGala)
8bee433 feat: Extract _validate_existing_db (PatrykGala)
Diff view
@@ -4,6 +4,15 @@
from __future__ import annotations

from pathlib import Path
from types import TracebackType

from neptune_scale.sync.operations_repository import (
    DB_VERSION,
    Metadata,
    OperationsRepository,
)

__all__ = ["Run"]

import atexit
@@ -22,7 +31,6 @@
from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import ForkPoint
from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import Run as CreateRun
from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation

from neptune_scale.api.attribute import AttributeStore
from neptune_scale.api.metrics import Metrics
@@ -35,6 +43,8 @@
)
from neptune_scale.exceptions import (
    NeptuneApiTokenNotProvided,
    NeptuneConflictingDataInLocalStorage,
    NeptuneLocalStorageInUnsupportedVersion,
    NeptuneProjectNotProvided,
)
from neptune_scale.net.serialization import (
@@ -46,20 +56,15 @@
    ErrorsQueue,
)
from neptune_scale.sync.lag_tracking import LagTracker
from neptune_scale.sync.operations_queue import OperationsQueue
from neptune_scale.sync.parameters import (
    MAX_EXPERIMENT_NAME_LENGTH,
    MAX_QUEUE_SIZE,
    MAX_RUN_ID_LENGTH,
    MINIMAL_WAIT_FOR_ACK_SLEEP_TIME,
    MINIMAL_WAIT_FOR_PUT_SLEEP_TIME,
    STOP_MESSAGE_FREQUENCY,
)
from neptune_scale.sync.sequence_tracker import SequenceTracker
from neptune_scale.sync.sync_process import SyncProcess
from neptune_scale.util.abstract import (
    Resource,
    WithResources,
)
from neptune_scale.util.envs import (
    API_TOKEN_ENV_NAME,
    PROJECT_ENV_NAME,
@@ -74,7 +79,7 @@
logger = get_logger()


class Run(WithResources, AbstractContextManager):
class Run(AbstractContextManager):
    """
    Representation of tracked metadata.
    """
@@ -91,7 +96,7 @@ def __init__(
        creation_time: Optional[datetime] = None,
        fork_run_id: Optional[str] = None,
        fork_step: Optional[Union[int, float]] = None,
        max_queue_size: int = MAX_QUEUE_SIZE,
        max_queue_size: Optional[int] = None,
        async_lag_threshold: Optional[float] = None,
        on_async_lag_callback: Optional[Callable[[], None]] = None,
        on_queue_full_callback: Optional[Callable[[BaseException, Optional[float]], None]] = None,
@@ -114,7 +119,7 @@ def __init__(
            creation_time: Custom creation time of the run.
            fork_run_id: If forking from an existing run, ID of the run to fork from.
            fork_step: If forking from an existing run, step number to fork from.
            max_queue_size: Maximum number of operations in a queue.
            max_queue_size: Deprecated.
            async_lag_threshold: Threshold for the duration between the queueing and synchronization of an operation
                (in seconds). If the duration exceeds the threshold, the callback function is triggered.
            on_async_lag_callback: Callback function triggered when the duration between the queueing and synchronization
@@ -134,7 +139,6 @@ def __init__(
        verify_type("creation_time", creation_time, (datetime, type(None)))
        verify_type("fork_run_id", fork_run_id, (str, type(None)))
        verify_type("fork_step", fork_step, (int, float, type(None)))
        verify_type("max_queue_size", max_queue_size, int)
        verify_type("async_lag_threshold", async_lag_threshold, (int, float, type(None)))
        verify_type("on_async_lag_callback", on_async_lag_callback, (Callable, type(None)))
        verify_type("on_queue_full_callback", on_queue_full_callback, (Callable, type(None)))
@@ -161,8 +165,8 @@ def __init__(
        ):
            raise ValueError("`on_async_lag_callback` must be used with `async_lag_threshold`.")

        if max_queue_size < 1:
            raise ValueError("`max_queue_size` must be greater than 0.")
        if max_queue_size is not None:
            logger.warning("`max_queue_size` is deprecated and will be removed in a future version.")

        project = project or os.environ.get(PROJECT_ENV_NAME)
        if project:
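A quick illustration of the new behavior (the project and run id below are placeholders, not taken from this PR): passing max_queue_size no longer raises for out-of-range values; it only emits the deprecation warning shown above, and the argument is not used anywhere else in the run setup.

# Hypothetical usage; the keyword is still accepted for backward compatibility,
# but it no longer controls any queue size.
run = Run(
    project="my-workspace/my-project",  # placeholder
    run_id="example-run",               # placeholder
    max_queue_size=10_000,              # logs the deprecation warning, otherwise ignored
)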
@@ -198,12 +202,24 @@ def __init__(
        self._run_id: str = run_id

        self._lock = threading.RLock()
        self._operations_queue: OperationsQueue = OperationsQueue(
            lock=self._lock,
            max_size=max_queue_size,

        operations_repository_path = _create_repository_path(self._project, self._run_id)
        self._operations_repo: OperationsRepository = OperationsRepository(
            db_path=operations_repository_path,
        )
        self._operations_repo.init_db()

        existing_metadata = self._operations_repo.get_metadata()

        self._attr_store: AttributeStore = AttributeStore(self._project, self._run_id, self._operations_queue)
        # Save metadata if it doesn't exist
        if existing_metadata is None:
            self._operations_repo.save_metadata(self._project, self._run_id, fork_run_id, fork_step)
        else:
            _validate_existing_db(existing_metadata, resume, self._project, self._run_id, fork_run_id, fork_step)

        self._sequence_tracker = SequenceTracker()

        self._attr_store: AttributeStore = AttributeStore(
            self._project, self._run_id, self._operations_repo, self._sequence_tracker
        )

        self._errors_queue: ErrorsQueue = ErrorsQueue()
        self._errors_monitor = ErrorsMonitor(
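Read in isolation, the repository bootstrap above can be sketched as below. This is a hedged illustration based only on the calls visible in this diff (OperationsRepository(db_path=...), init_db(), get_metadata(), save_metadata(project, run_id, fork_run_id, fork_step), close()); the path and identifiers are placeholders.

from pathlib import Path

from neptune_scale.sync.operations_repository import OperationsRepository

# Placeholder path mirroring what _create_repository_path produces (see the helper below).
db_path = Path.cwd() / ".neptune" / "my-workspace_my-project_example-run.sqlite3"

repo = OperationsRepository(db_path=db_path)
repo.init_db()

metadata = repo.get_metadata()
if metadata is None:
    # First use of this file: persist the identifying metadata.
    repo.save_metadata("my-workspace/my-project", "example-run", None, None)
else:
    # Existing file: the caller must check it for conflicts, as _validate_existing_db does below.
    assert metadata.project == "my-workspace/my-project"

repo.close()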
@@ -222,21 +238,20 @@ def __init__(
        self._sync_process = SyncProcess(
            project=self._project,
            family=self._run_id,
            operations_queue=self._operations_queue.queue,
            operations_repository_path=operations_repository_path,
            errors_queue=self._errors_queue,
            process_link=self._process_link,
            api_token=input_api_token,
            last_queued_seq=self._last_queued_seq,
            last_ack_seq=self._last_ack_seq,
            last_ack_timestamp=self._last_ack_timestamp,
            max_queue_size=max_queue_size,
            mode=mode,
        )
        self._lag_tracker: Optional[LagTracker] = None
        if async_lag_threshold is not None and on_async_lag_callback is not None:
            self._lag_tracker = LagTracker(
                errors_queue=self._errors_queue,
                operations_queue=self._operations_queue,
                sequence_tracker=self._sequence_tracker,
                last_ack_timestamp=self._last_ack_timestamp,
                async_lag_threshold=async_lag_threshold,
                on_async_lag_callback=on_async_lag_callback,
@@ -251,6 +266,7 @@ def __init__(
        self._exit_func: Optional[Callable[[], None]] = atexit.register(self._close)

        if not resume:
            # Create a new run
            self._create_run(
                creation_time=datetime.now() if creation_time is None else creation_time,
                experiment_name=experiment_name,
@@ -266,21 +282,6 @@ def _on_child_link_closed(self, _: ProcessLink) -> None:
            self._is_closing = True
            self.terminate()

    @property
    def resources(self) -> tuple[Resource, ...]:
        if self._lag_tracker is not None:
            return (
                self._errors_queue,
                self._operations_queue,
                self._lag_tracker,
                self._errors_monitor,
            )
        return (
            self._errors_queue,
            self._operations_queue,
            self._errors_monitor,
        )

    def _close(self, *, wait: bool = True) -> None:
        with self._lock:
            if self._is_closing:
@@ -310,7 +311,8 @@ def _close(self, *, wait: bool = True) -> None:
        if threading.current_thread() != self._errors_monitor:
            self._errors_monitor.join()

        super().close()
        self._operations_repo.close()

        self._errors_queue.close()

    def terminate(self) -> None:
        """
@@ -363,6 +365,14 @@ def close(self) -> None:
        self._exit_func = None
        self._close(wait=True)

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        self.close()

    def _create_run(
        self,
        creation_time: datetime,
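With __exit__ in place, a Run can be used directly as a context manager, and close() runs automatically when the block exits, whether normally or via an exception. A minimal usage sketch (the project, run id, and log_metrics arguments are placeholders, not taken from this PR):

# Hypothetical usage; __exit__ calls close() on both normal exit and exceptions.
with Run(project="my-workspace/my-project", run_id="example-run") as run:
    run.log_metrics(data={"loss": 0.25}, step=1)
# At this point close() has run, which also closes the on-disk operations repository.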
@@ -376,17 +386,15 @@ def _create_run(
                parent_project=self._project, parent_run_id=fork_run_id, step=make_step(number=fork_step)
            )

        operation = RunOperation(
            project=self._project,
            run_id=self._run_id,
            create=CreateRun(
                family=self._run_id,
                fork_point=fork_point,
                experiment_id=experiment_name,
                creation_time=None if creation_time is None else datetime_to_proto(creation_time),
            ),
        create_run = CreateRun(
            family=self._run_id,
            fork_point=fork_point,
            experiment_id=experiment_name,
            creation_time=None if creation_time is None else datetime_to_proto(creation_time),
        )
        self._operations_queue.enqueue(operation=operation)

        sequence = self._operations_repo.save_create_run(create_run)
        self._sequence_tracker.update_sequence_id(sequence)

    def log_metrics(
        self,
@@ -604,20 +612,20 @@ def _wait(

            # Handle the case where we get notified on `wait_seq` before we actually wait.
            # Otherwise, we would unnecessarily block, waiting on a notify_all() that never happens.
            if wait_seq.value >= self._operations_queue.last_sequence_id:
            if wait_seq.value >= self._sequence_tracker.last_sequence_id:
                break

            with wait_seq:
                wait_seq.wait(timeout=wait_time)
                value = wait_seq.value

            last_queued_sequence_id = self._operations_queue.last_sequence_id
            last_queued_sequence_id = self._sequence_tracker.last_sequence_id

            if value == -1:
                if self._operations_queue.last_sequence_id != -1:
                if self._sequence_tracker.last_sequence_id != -1:
                    last_print_timestamp = print_message(
                        f"Waiting. No operations were {phrase} yet. Operations to sync: %s",
                        self._operations_queue.last_sequence_id + 1,
                        self._sequence_tracker.last_sequence_id + 1,
                        last_print=last_print_timestamp,
                        verbose=verbose,
                    )
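_wait now reads progress from the SequenceTracker instead of the in-memory operations queue. Based only on the members used in this diff (last_sequence_id and update_sequence_id()), the contract Run relies on can be sketched roughly as follows; this is not the actual neptune_scale.sync.sequence_tracker implementation:

class SequenceTrackerSketch:
    # Rough stand-in for the contract assumed above, for illustration only.
    def __init__(self) -> None:
        # -1 means "nothing has been queued yet", matching the checks in _wait.
        self.last_sequence_id: int = -1

    def update_sequence_id(self, sequence_id: int) -> None:
        # Sequence ids come from OperationsRepository.save_* calls and only move forward.
        self.last_sequence_id = max(self.last_sequence_id, sequence_id)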
@@ -692,3 +700,35 @@ def print_message(msg: str, *args: Any, last_print: Optional[float] = None, verb
        return current_time

    return last_print


def _create_repository_path(project: str, run_id: str) -> Path:
    sanitized_project = project.replace("/", "_")
    return Path(os.getcwd()) / ".neptune" / f"{sanitized_project}_{run_id}.sqlite3"

Review note on _create_repository_path: TODO: ideally we need to come up with something unambiguous, but usable (url encoding is not usable). Need to think some more.


def _validate_existing_db(
    existing_metadata: Metadata,
    resume: bool,
    project: str,
    run_id: str,
    fork_run_id: Optional[str],
    fork_step: Optional[float],
) -> None:
    if existing_metadata.version != DB_VERSION:
        raise NeptuneLocalStorageInUnsupportedVersion()

    if existing_metadata.project != project or existing_metadata.run_id != run_id:
        # should never happen because we use project and run_id to create the repository path
        raise NeptuneConflictingDataInLocalStorage()

    # Check for conflicts when not resuming, because we don't allow fork points in resumed runs
    if resume:
        return

    if existing_metadata.parent_run_id == fork_run_id and existing_metadata.fork_step == fork_step:
        logger.warning("Run already exists in local storage with the same parent run and fork point. Resuming the run.")
        return
    else:
        # Same run_id but different fork points
        raise NeptuneConflictingDataInLocalStorage()
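Two quick, hedged examples of the helpers above. The first shows the path produced by _create_repository_path for placeholder inputs; the second exercises the fork-point conflict check in _validate_existing_db, assuming Metadata is a plain record exposing the fields accessed above (version, project, run_id, parent_run_id, fork_step) and accepting them as keyword arguments.

# "/" in the project name is replaced with "_" when building the file name.
path = _create_repository_path("my-workspace/my-project", "example-run")
# -> <current working directory>/.neptune/my-workspace_my-project_example-run.sqlite3

# Same run id but a different fork step is treated as conflicting local data.
existing = Metadata(
    version=DB_VERSION,
    project="my-workspace/my-project",
    run_id="example-run",
    parent_run_id="parent-run",
    fork_step=10.0,
)
try:
    _validate_existing_db(existing, False, "my-workspace/my-project", "example-run", "parent-run", 20.0)
except NeptuneConflictingDataInLocalStorage:
    pass  # a different fork point for the same run id raises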