Skip to content

Use ProcessLink for tracking process termination between Run <-> SyncProcess #61

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 11 additions & 17 deletions src/neptune_scale/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@

import atexit
import os
import platform
import signal
import threading
import time
from contextlib import AbstractContextManager
Expand Down Expand Up @@ -43,6 +41,7 @@
from neptune_scale.core.components.sync_process import SyncProcess
from neptune_scale.core.logger import get_logger
from neptune_scale.core.metadata_splitter import MetadataSplitter
from neptune_scale.core.process_link import ProcessLink
from neptune_scale.core.serialization import (
datetime_to_proto,
make_step,
Expand All @@ -51,7 +50,6 @@
SharedFloat,
SharedInt,
)
from neptune_scale.core.util import safe_signal_name
from neptune_scale.core.validation import (
verify_collection_type,
verify_max_length,
Expand Down Expand Up @@ -220,11 +218,13 @@ def __init__(
self._last_ack_seq = SharedInt(-1)
self._last_ack_timestamp = SharedFloat(-1)

self._process_link = ProcessLink()
self._sync_process = SyncProcess(
project=self._project,
family=self._run_id,
operations_queue=self._operations_queue.queue,
errors_queue=self._errors_queue,
process_link=self._process_link,
api_token=input_api_token,
last_queued_seq=self._last_queued_seq,
last_ack_seq=self._last_ack_seq,
Expand All @@ -246,13 +246,10 @@ def __init__(
self._errors_monitor.start()
with self._lock:
self._sync_process.start()
self._process_link.start(on_link_closed=self._on_child_link_closed)

self._exit_func: Optional[Callable[[], None]] = atexit.register(self._close)

if platform.system() != "Windows":
# Ignoring the type because the signal module is not available on Windows
signal.signal(signal.SIGCHLD, self._handle_signal) # type: ignore[attr-defined]

if not resume:
self._create_run(
creation_time=datetime.now() if creation_time is None else creation_time,
Expand All @@ -262,16 +259,12 @@ def __init__(
)
self.wait_for_processing(verbose=False)

def _handle_signal(self, signum: int, frame: Any) -> None:
# We should not be concerned about SIGCHLD if it's not about our child process
if signum == signal.SIGCHLD and self._sync_process.is_alive():
return

if not self._is_closing:
signame = safe_signal_name(signum)
logger.debug(f"Received signal {signame}. Terminating.")

self.terminate()
def _on_child_link_closed(self, _: ProcessLink) -> None:
with self._lock:
if not self._is_closing:
logger.error("Child process closed unexpectedly. Terminating.")
self._is_closing = True
self.terminate()

@property
def resources(self) -> tuple[Resource, ...]:
Expand Down Expand Up @@ -303,6 +296,7 @@ def _close(self, *, wait: bool = True) -> None:

self._sync_process.terminate()
self._sync_process.join()
self._process_link.stop()

if self._lag_tracker is not None:
self._lag_tracker.interrupt()
Expand Down
9 changes: 8 additions & 1 deletion src/neptune_scale/core/components/sync_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
SingleOperation,
)
from neptune_scale.core.logger import get_logger
from neptune_scale.core.process_link import ProcessLink
from neptune_scale.core.shared_var import (
SharedFloat,
SharedInt,
Expand Down Expand Up @@ -195,6 +196,7 @@ def __init__(
self,
operations_queue: Queue,
errors_queue: ErrorsQueue,
process_link: ProcessLink,
api_token: str,
project: str,
family: str,
Expand All @@ -208,6 +210,7 @@ def __init__(

self._external_operations_queue: Queue[SingleOperation] = operations_queue
self._errors_queue: ErrorsQueue = errors_queue
self._process_link: ProcessLink = process_link
self._api_token: str = api_token
self._project: str = project
self._family: str = family
Expand All @@ -224,10 +227,14 @@ def _handle_signal(self, signum: int, frame: Optional[FrameType]) -> None:
logger.debug(f"Received signal {safe_signal_name(signum)}")
self._stop_event.set() # Trigger the stop event

def _on_parent_link_closed(self, _: ProcessLink) -> None:
logger.error("SyncProcess: main process closed unexpectedly. Exiting")
self._stop_event.set()

def run(self) -> None:
logger.info("Data synchronization started")

# Register signals handlers
self._process_link.start(on_link_closed=self._on_parent_link_closed)
signal.signal(signal.SIGTERM, self._handle_signal)

worker = SyncProcessWorker(
Expand Down
Loading