Skip to content

feat: remove SharedVars, SequenceTracker, status_tracking_queue. Add run_operation_submission table. Return bool status from wait #272

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
283 changes: 128 additions & 155 deletions src/neptune_scale/api/run.py

Large diffs are not rendered by default.

27 changes: 8 additions & 19 deletions src/neptune_scale/cli/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,7 @@
SequenceId,
)
from neptune_scale.sync.sync_process import run_sync_process
from neptune_scale.util import (
SharedFloat,
SharedInt,
get_logger,
)
from neptune_scale.util import get_logger
from neptune_scale.util.timer import Timer

logger = get_logger()
Expand Down Expand Up @@ -87,9 +83,6 @@ def __init__(

self._spawn_mp_context = multiprocessing.get_context("spawn")
self._errors_queue: ErrorsQueue = ErrorsQueue(self._spawn_mp_context)
self._last_queued_seq = SharedInt(self._spawn_mp_context, -1)
self._last_ack_seq = SharedInt(self._spawn_mp_context, -1)
self._last_ack_timestamp = SharedFloat(self._spawn_mp_context, -1)

self._log_seq_id_range: Optional[tuple[SequenceId, SequenceId]] = None
self._file_upload_request_init_count: Optional[int] = None
Expand All @@ -99,10 +92,10 @@ def __init__(
def start(
self,
) -> None:
self._log_seq_id_range = self._operations_repository.get_sequence_id_range()
self._log_seq_id_range = self._operations_repository.get_operations_sequence_id_range()
self._file_upload_request_init_count = self._operations_repository.get_file_upload_requests_count()

if self._log_seq_id_range is None:
if self._log_seq_id_range is None and self._file_upload_request_init_count == 0:
logger.info("No operations to process")
return

Expand All @@ -120,9 +113,6 @@ def start(
"operations_repository_path": self._run_log_file,
"errors_queue": self._errors_queue,
"api_token": self._api_token,
"last_queued_seq": self._last_queued_seq,
"last_ack_seq": self._last_ack_seq,
"last_ack_timestamp": self._last_ack_timestamp,
},
)

Expand Down Expand Up @@ -187,14 +177,13 @@ def _wait_operation_submit(self, last_progress: _ProgressStatus, wait_time: floa
return last_progress
assert self._log_seq_id_range is not None

with self._last_ack_seq:
self._last_ack_seq.wait(timeout=wait_time)
last_ack_seq_id = self._last_ack_seq.value
log_seq_id_range = self._operations_repository.get_operations_sequence_id_range()

if last_ack_seq_id != -1:
acked_count = last_ack_seq_id - self._log_seq_id_range[0] + 1
if log_seq_id_range is not None:
acked_count = log_seq_id_range[0] - self._log_seq_id_range[0]
time.sleep(wait_time)
else:
acked_count = 0
acked_count = self._log_seq_id_range[1] - self._log_seq_id_range[0] + 1

return last_progress.updated(progress=acked_count)

Expand Down
56 changes: 17 additions & 39 deletions src/neptune_scale/sync/lag_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,58 +5,36 @@
import time
from collections.abc import Callable

from neptune_scale.sync.parameters import (
LAG_TRACKER_THREAD_SLEEP_TIME,
LAG_TRACKER_TIMEOUT,
)
from neptune_scale.sync.sequence_tracker import SequenceTracker
from neptune_scale.util import (
Daemon,
SharedFloat,
)
from neptune_scale.sync.operations_repository import OperationsRepository
from neptune_scale.sync.parameters import LAG_TRACKER_THREAD_SLEEP_TIME
from neptune_scale.util import Daemon


class LagTracker(Daemon):
def __init__(
self,
sequence_tracker: SequenceTracker,
last_ack_timestamp: SharedFloat,
operations_repository: OperationsRepository,
async_lag_threshold: float,
on_async_lag_callback: Callable[[], None],
) -> None:
super().__init__(name="LagTracker", sleep_time=LAG_TRACKER_THREAD_SLEEP_TIME)

self._sequence_tracker: SequenceTracker = sequence_tracker
self._last_ack_timestamp: SharedFloat = last_ack_timestamp
self._operations_repository = operations_repository
self._async_lag_threshold: float = async_lag_threshold
self._on_async_lag_callback: Callable[[], None] = on_async_lag_callback

self._callback_triggered: bool = False

def work(self) -> None:
with self._last_ack_timestamp:
self._last_ack_timestamp.wait(timeout=LAG_TRACKER_TIMEOUT)
last_ack_timestamp = self._last_ack_timestamp.value
last_queued_timestamp = self._sequence_tracker.last_timestamp

# No operations were put into the queue
if last_queued_timestamp is None:
return

# No operations were processed by server
if last_ack_timestamp < 0 and not self._callback_triggered:
if time.time() - last_queued_timestamp > self._async_lag_threshold:
self._callback_triggered = True
self._on_async_lag_callback()
return

self._callback_triggered = False
else:
# Some operations were processed by server
if last_queued_timestamp - last_ack_timestamp > self._async_lag_threshold:
if not self._callback_triggered:
self._callback_triggered = True
self._on_async_lag_callback()
return

self._callback_triggered = False
oldest_queued_timestamp = self._operations_repository.get_operations_min_timestamp()
current_timestamp = time.time()

if (
oldest_queued_timestamp is not None
and current_timestamp - oldest_queued_timestamp.timestamp() > self._async_lag_threshold
):
if not self._callback_triggered:
self._on_async_lag_callback()
self._callback_triggered = True
else:
self._callback_triggered = False
172 changes: 147 additions & 25 deletions src/neptune_scale/sync/operations_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@
__all__ = ("OperationsRepository", "OperationType", "Operation", "Metadata", "SequenceId", "FileUploadRequest")

import contextlib
import datetime
import os
import sqlite3
import threading
import time
import typing
from contextlib import AbstractContextManager
from dataclasses import dataclass
from datetime import datetime
from enum import IntEnum
from typing import (
Literal,
Expand All @@ -39,10 +39,11 @@

logger = get_logger()

DB_VERSION = "v3"
BACKWARD_COMPATIBLE_DB_VERSIONS = ("v2", "v3")
DB_VERSION = "v4"
BACKWARD_COMPATIBLE_DB_VERSIONS = ("v4",)

SequenceId = typing.NewType("SequenceId", int)
RequestId = typing.NewType("RequestId", str)


class OperationType(IntEnum):
Expand All @@ -59,8 +60,19 @@ class Operation:
operation_size_bytes: int

@property
def ts(self) -> datetime.datetime:
return datetime.datetime.fromtimestamp(self.timestamp / 1000)
def ts(self) -> datetime:
return datetime.fromtimestamp(self.timestamp / 1000)


@dataclass(frozen=True)
class OperationSubmission:
    """Immutable record of a submitted operation batch, persisted in the
    run_operation_submission table (see get/save/delete_operation_submissions).
    """

    # Row key; an INTEGER PRIMARY KEY in the table, so it doubles as the rowid.
    sequence_id: SequenceId
    # Milliseconds since the epoch (divided by 1000 in `ts` below).
    timestamp: int
    # Identifier correlating this submission with a request — presumably the
    # server-side request id; verify against the producer of these records.
    request_id: RequestId

    @property
    def ts(self) -> datetime:
        """`timestamp` converted to a naive local-time datetime."""
        return datetime.fromtimestamp(self.timestamp / 1000)


@dataclass(frozen=True)
Expand Down Expand Up @@ -149,6 +161,16 @@ def init_db(self) -> None:
)"""
)

conn.execute(
"""
CREATE TABLE IF NOT EXISTS run_operation_submission (
sequence_id INTEGER PRIMARY KEY,
timestamp INTEGER NOT NULL,
request_id TEXT NOT NULL
)
"""
)

conn.execute(
"""
CREATE TABLE IF NOT EXISTS file_upload_requests (
Expand Down Expand Up @@ -248,7 +270,7 @@ def _insert_update_run_snapshots(self, ops: list[bytes], current_time: int) -> O
except NeptuneUnableToLogData:
if self._log_failure_action == "raise":
raise
if self._log_failure_action == "drop":
else:
logger.error(f"Dropping {len(ops)} operations due to error", exc_info=True)
return None
Comment on lines +273 to 275
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (code-quality): Remove unnecessary else after guard condition (remove-unnecessary-else)

Suggested change
else:
logger.error(f"Dropping {len(ops)} operations due to error", exc_info=True)
return None
logger.error(f"Dropping {len(ops)} operations due to error", exc_info=True)
return None


Expand Down Expand Up @@ -319,6 +341,49 @@ def delete_operations(self, up_to_seq_id: SequenceId) -> int:
# Return the number of rows affected
return cursor.rowcount or 0

def get_operation_count(self, limit: Optional[int] = None) -> int:
    """Return the number of rows in run_operations.

    ``limit`` is forwarded to the shared counting helper — presumably a cap on
    how many rows are counted; confirm against _get_table_count.
    """
    with self._get_connection() as conn, contextlib.closing(conn.cursor()) as cursor:  # type: ignore
        return self._get_table_count(cursor, "run_operations", limit=limit)

def get_operations_sequence_id_range(self) -> Optional[tuple[SequenceId, SequenceId]]:
    """Return the (lowest, highest) sequence_id present in run_operations,
    or None when the table is empty.
    """
    with self._get_connection() as conn, contextlib.closing(conn.cursor()) as cursor:  # type: ignore
        cursor.execute(
            """
            SELECT MIN(sequence_id), MAX(sequence_id)
            FROM run_operations
            """
        )
        result_row = cursor.fetchone()

    if result_row is None:
        return None

    # An aggregate query over an empty table yields a single row of NULLs.
    lowest, highest = result_row
    if lowest is None or highest is None:
        return None
    return SequenceId(lowest), SequenceId(highest)

def get_operations_min_timestamp(self) -> Optional[datetime]:
    """Return the timestamp of the oldest operation in run_operations
    (the row with the lowest sequence_id), or None when the table is empty.

    Stored timestamps are in milliseconds since the epoch; the result is a
    naive local-time datetime.
    """
    with self._get_connection() as conn, contextlib.closing(conn.cursor()) as cursor:  # type: ignore
        cursor.execute(
            """
            SELECT timestamp
            FROM run_operations
            ORDER BY sequence_id ASC
            LIMIT 1
            """
        )
        result_row = cursor.fetchone()

    if result_row is None:
        return None
    (millis,) = result_row
    return datetime.fromtimestamp(millis / 1000)

def save_metadata(self, project: str, run_id: str) -> None:
with self._get_connection() as conn: # type: ignore
with contextlib.closing(conn.cursor()) as cursor:
Expand Down Expand Up @@ -363,25 +428,6 @@ def get_metadata(self) -> Optional[Metadata]:

return Metadata(project=project, run_id=run_id)

def get_sequence_id_range(self) -> Optional[tuple[SequenceId, SequenceId]]:
    """Return the (minimum, maximum) sequence_id stored in run_operations,
    or None when the table holds no rows.
    """
    with self._get_connection() as conn:  # type: ignore
        with contextlib.closing(conn.cursor()) as cursor:
            cursor.execute(
                """
                SELECT MIN(sequence_id), MAX(sequence_id)
                FROM run_operations
                """
            )

            row = cursor.fetchone()
            if not row:
                return None

            # MIN/MAX over an empty table produce one row of NULLs, so check
            # the values as well as the row itself.
            min_seq_id, max_seq_id = row
            if min_seq_id is None or max_seq_id is None:
                return None
            return SequenceId(min_seq_id), SequenceId(max_seq_id)

def save_file_upload_requests(self, files: list[FileUploadRequest]) -> SequenceId:
with self._get_connection() as conn: # type: ignore
with contextlib.closing(conn.cursor()) as cursor:
Expand Down Expand Up @@ -440,6 +486,82 @@ def get_file_upload_requests_count(self, limit: Optional[int] = None) -> int:
with contextlib.closing(conn.cursor()) as cursor:
return self._get_table_count(cursor, "file_upload_requests", limit=limit)

def save_operation_submissions(self, submissions: list[OperationSubmission]) -> SequenceId:
    """Insert submission records into run_operation_submission and return the
    sequence_id of the last row inserted on this connection.

    sequence_id is the table's INTEGER PRIMARY KEY (i.e. the rowid), so
    ``last_insert_rowid()`` reports the last submission's sequence_id.

    NOTE(review): for an empty ``submissions`` list, executemany inserts
    nothing and ``last_insert_rowid()`` returns whatever this connection last
    inserted (0 if nothing yet) — confirm callers never pass an empty list.
    """
    with self._get_connection() as conn:  # type: ignore
        with contextlib.closing(conn.cursor()) as cursor:
            cursor.executemany(
                """
                INSERT INTO run_operation_submission (sequence_id, timestamp, request_id)
                VALUES (?, ?, ?)
                """,
                [(status.sequence_id, status.timestamp, status.request_id) for status in submissions],
            )
            cursor.execute("SELECT last_insert_rowid()")
            return SequenceId(cursor.fetchone()[0])

def get_operation_submissions(self, limit: int) -> list[OperationSubmission]:
    """Fetch up to ``limit`` submission records, ordered by ascending sequence_id."""
    with self._get_connection() as conn, contextlib.closing(conn.cursor()) as cursor:  # type: ignore
        cursor.execute(
            """
            SELECT sequence_id, timestamp, request_id
            FROM run_operation_submission
            ORDER BY sequence_id ASC
            LIMIT ?
            """,
            (limit,),
        )
        fetched = cursor.fetchall()

    # Wrap raw rows in the typed record; NewType wrappers are identity at runtime.
    return [
        OperationSubmission(
            sequence_id=SequenceId(seq_id),
            timestamp=ts,
            request_id=RequestId(req_id),
        )
        for seq_id, ts, req_id in fetched
    ]

def delete_operation_submissions(self, up_to_seq_id: Optional[SequenceId]) -> int:
    """Delete submission records with sequence_id <= ``up_to_seq_id``
    (all records when None) and return the number of rows removed.

    Non-positive ids are treated as a no-op and return 0.
    """
    if up_to_seq_id is not None and up_to_seq_id <= 0:
        return 0

    with self._get_connection() as conn, contextlib.closing(conn.cursor()) as cursor:  # type: ignore
        if up_to_seq_id is None:
            cursor.execute("DELETE FROM run_operation_submission")
        else:
            cursor.execute(
                "DELETE FROM run_operation_submission WHERE sequence_id <= ?",
                (up_to_seq_id,),
            )
        deleted = cursor.rowcount

    return deleted or 0

def get_operation_submission_count(self, limit: Optional[int] = None) -> int:
    """Return the number of rows in run_operation_submission.

    ``limit`` is forwarded to the shared counting helper — presumably a cap on
    how many rows are counted; confirm against _get_table_count.
    """
    with self._get_connection() as conn, contextlib.closing(conn.cursor()) as cursor:  # type: ignore
        return self._get_table_count(cursor, "run_operation_submission", limit=limit)

def get_operation_submission_sequence_id_range(self) -> Optional[tuple[SequenceId, SequenceId]]:
    """Return the (lowest, highest) sequence_id in run_operation_submission,
    or None when the table is empty.
    """
    with self._get_connection() as conn, contextlib.closing(conn.cursor()) as cursor:  # type: ignore
        cursor.execute(
            """
            SELECT MIN(sequence_id), MAX(sequence_id)
            FROM run_operation_submission
            """
        )
        result_row = cursor.fetchone()

    if result_row is None:
        return None

    # MIN/MAX over an empty table come back as a single row of NULLs.
    lowest, highest = result_row
    if lowest is None or highest is None:
        return None
    return SequenceId(lowest), SequenceId(highest)

def _is_repository_empty(self) -> bool:
with self._get_connection() as conn: # type: ignore
with contextlib.closing(conn.cursor()) as cursor:
Expand Down
3 changes: 0 additions & 3 deletions src/neptune_scale/sync/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,7 @@

# User facing
SHUTDOWN_TIMEOUT = 60 # 1 minute
MINIMAL_WAIT_FOR_PUT_SLEEP_TIME = 10
MINIMAL_WAIT_FOR_ACK_SLEEP_TIME = 10
STOP_MESSAGE_FREQUENCY = 5
LAG_TRACKER_TIMEOUT = 1
OPERATION_REPOSITORY_POLL_SLEEP_TIME = 1

# Status tracking
Expand Down
Loading
Loading