diff --git a/src/neptune_scale/__init__.py b/src/neptune_scale/__init__.py index 5a044341..c3a5f2e1 100644 --- a/src/neptune_scale/__init__.py +++ b/src/neptune_scale/__init__.py @@ -8,8 +8,6 @@ import atexit import os -import platform -import signal import threading import time from contextlib import AbstractContextManager @@ -43,6 +41,7 @@ from neptune_scale.core.components.sync_process import SyncProcess from neptune_scale.core.logger import get_logger from neptune_scale.core.metadata_splitter import MetadataSplitter +from neptune_scale.core.process_link import ProcessLink from neptune_scale.core.serialization import ( datetime_to_proto, make_step, @@ -51,7 +50,6 @@ SharedFloat, SharedInt, ) -from neptune_scale.core.util import safe_signal_name from neptune_scale.core.validation import ( verify_collection_type, verify_max_length, @@ -220,11 +218,13 @@ def __init__( self._last_ack_seq = SharedInt(-1) self._last_ack_timestamp = SharedFloat(-1) + self._process_link = ProcessLink() self._sync_process = SyncProcess( project=self._project, family=self._run_id, operations_queue=self._operations_queue.queue, errors_queue=self._errors_queue, + process_link=self._process_link, api_token=input_api_token, last_queued_seq=self._last_queued_seq, last_ack_seq=self._last_ack_seq, @@ -246,13 +246,10 @@ def __init__( self._errors_monitor.start() with self._lock: self._sync_process.start() + self._process_link.start(on_link_closed=self._on_child_link_closed) self._exit_func: Optional[Callable[[], None]] = atexit.register(self._close) - if platform.system() != "Windows": - # Ignoring the type because the signal module is not available on Windows - signal.signal(signal.SIGCHLD, self._handle_signal) # type: ignore[attr-defined] - if not resume: self._create_run( creation_time=datetime.now() if creation_time is None else creation_time, @@ -262,16 +259,12 @@ def __init__( ) self.wait_for_processing(verbose=False) - def _handle_signal(self, signum: int, frame: Any) -> None: - # We should not be concerned about SIGCHLD if it's not about our child process - if signum == signal.SIGCHLD and self._sync_process.is_alive(): - return - - if not self._is_closing: - signame = safe_signal_name(signum) - logger.debug(f"Received signal {signame}. Terminating.") - - self.terminate() + def _on_child_link_closed(self, _: ProcessLink) -> None: + with self._lock: + if not self._is_closing: + logger.error("Child process closed unexpectedly. Terminating.") + self._is_closing = True + self.terminate() @property def resources(self) -> tuple[Resource, ...]: @@ -303,6 +296,7 @@ def _close(self, *, wait: bool = True) -> None: self._sync_process.terminate() self._sync_process.join() + self._process_link.stop() if self._lag_tracker is not None: self._lag_tracker.interrupt() diff --git a/src/neptune_scale/core/components/sync_process.py b/src/neptune_scale/core/components/sync_process.py index bd586448..5add8cf3 100644 --- a/src/neptune_scale/core/components/sync_process.py +++ b/src/neptune_scale/core/components/sync_process.py @@ -57,6 +57,7 @@ SingleOperation, ) from neptune_scale.core.logger import get_logger +from neptune_scale.core.process_link import ProcessLink from neptune_scale.core.shared_var import ( SharedFloat, SharedInt, @@ -195,6 +196,7 @@ def __init__( self, operations_queue: Queue, errors_queue: ErrorsQueue, + process_link: ProcessLink, api_token: str, project: str, family: str, @@ -208,6 +210,7 @@ def __init__( self._external_operations_queue: Queue[SingleOperation] = operations_queue self._errors_queue: ErrorsQueue = errors_queue + self._process_link: ProcessLink = process_link self._api_token: str = api_token self._project: str = project self._family: str = family @@ -224,10 +227,14 @@ def _handle_signal(self, signum: int, frame: Optional[FrameType]) -> None: logger.debug(f"Received signal {safe_signal_name(signum)}") self._stop_event.set() # Trigger the stop event + def _on_parent_link_closed(self, _: ProcessLink) -> None: + logger.error("SyncProcess: main process closed unexpectedly. Exiting") + self._stop_event.set() + def run(self) -> None: logger.info("Data synchronization started") - # Register signals handlers + self._process_link.start(on_link_closed=self._on_parent_link_closed) signal.signal(signal.SIGTERM, self._handle_signal) worker = SyncProcessWorker(