restore patch on SpawnProcess.terminate #269

Closed · wants to merge 7 commits
2 changes: 1 addition & 1 deletion pyproject.toml

@@ -100,7 +100,7 @@ warn_unused_ignores = "True"
 ignore_missing_imports = "True"

 [tool.pytest.ini_options]
-addopts = "--doctest-modules -n auto"
+addopts = "--doctest-modules -n auto -o log_cli=true --log-cli-level=DEBUG --log-level=DEBUG"

 [tool.poetry.scripts]
 neptune = "neptune_scale.cli.commands:main"
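For reviewers unfamiliar with these flags: -o log_cli=true turns on pytest's live-log display, and the two level flags set the live-log and capture thresholds to DEBUG. Passing them through addopts is equivalent to setting the ini options directly; a sketch for illustration only, not part of the diff:

[tool.pytest.ini_options]
log_cli = true              # stream log records to the terminal as tests run
log_cli_level = "DEBUG"     # threshold for the live display
log_level = "DEBUG"         # threshold for records captured by pytest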
20 changes: 16 additions & 4 deletions src/neptune_scale/api/run.py

@@ -105,6 +105,9 @@

 logger = get_logger()

+# Global multiprocessing context for the whole program
+GLOBAL_MP_CONTEXT = multiprocessing.get_context("spawn")
+

 class Run(AbstractContextManager):
     """
@@ -292,7 +295,7 @@ def __init__(
         if self._api_token is None:
             raise NeptuneApiTokenNotProvided()

-        spawn_mp_context = multiprocessing.get_context("spawn")
+        spawn_mp_context = GLOBAL_MP_CONTEXT

         self._errors_queue: Optional[ErrorsQueue] = ErrorsQueue(spawn_mp_context)
         self._errors_monitor: Optional[ErrorsMonitor] = ErrorsMonitor(
@@ -377,6 +380,7 @@ def __init__(
         self._console_log_capture.start()

     def _handle_sync_process_death(self) -> None:
+        logger.info(f"[{os.environ.get('PYTEST_XDIST_TESTRUNUID')}] Handling sync process death.")
         with self._lock:
             if not self._is_closing:
                 if self._errors_queue is not None:
@@ -385,7 +389,7 @@ def _handle_sync_process_death(self) -> None:
     def _close(self, *, timeout: Optional[float] = None) -> None:
         timer = Timer(timeout)

-        # Console log capture is actually a produced of logged data, so let's flush it before closing.
+        # Console log capture is actually a producer of logged data, so let's flush it before closing.
         # This allows to log tracebacks of the potential exception that caused the run to fail.
         if self._console_log_capture is not None:
             self._console_log_capture.interrupt(remaining_iterations=0 if timer.is_expired() else 1)
@@ -1034,10 +1038,18 @@ def _wait(
             value = wait_seq.value
             if value >= self._sequence_tracker.last_sequence_id:
                 if verbose:
-                    logger.info(f"All operations were {phrase}")
+                    logger.info(
+                        f"[{os.environ.get('PYTEST_XDIST_TESTRUNUID')}] "
+                        f"All operations were {phrase} (state: "
+                        f"{value} >= {self._sequence_tracker.last_sequence_id})"
+                    )
                 return True
             else:
-                logger.warning("Waiting interrupted because sync process is not running")
+                logger.warning(
+                    f"[{os.environ.get('PYTEST_XDIST_TESTRUNUID')}] "
+                    f"Waiting interrupted because sync process is not running (state: "
+                    f"{value} < {self._sequence_tracker.last_sequence_id})"
+                )
                 return False

             # Handle the case where we get notified on `wait_seq` before we actually wait.
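The PR title suggests what GLOBAL_MP_CONTEXT is for: with a single context object exposed at module level, a test patch applied to that context's process class reliably covers every spawn site. A minimal sketch of the idea, assuming tests patch terminate on the context's Process class (the names below are illustrative, not taken from this diff):

from unittest.mock import patch

from neptune_scale.api.run import GLOBAL_MP_CONTEXT

calls = []

def recording_terminate(self):
    # Illustrative stand-in: record the call instead of delivering SIGTERM.
    calls.append(self.pid)

# GLOBAL_MP_CONTEXT.Process is the SpawnProcess class; because Run and
# SyncRunner now share this one context, a single patch covers both paths.
with patch.object(GLOBAL_MP_CONTEXT.Process, "terminate", recording_terminate):
    ...  # exercise code that terminates the sync process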
3 changes: 2 additions & 1 deletion src/neptune_scale/cli/sync.py

@@ -25,6 +25,7 @@

 from tqdm import tqdm

+from neptune_scale.api.run import GLOBAL_MP_CONTEXT
 from neptune_scale.sync.errors_tracking import (
     ErrorsMonitor,
     ErrorsQueue,
@@ -85,7 +86,7 @@ def __init__(
         self._run_log_file: Path = run_log_file
         self._operations_repository: OperationsRepository = OperationsRepository(db_path=run_log_file)

-        self._spawn_mp_context = multiprocessing.get_context("spawn")
+        self._spawn_mp_context = GLOBAL_MP_CONTEXT
         self._errors_queue: ErrorsQueue = ErrorsQueue(self._spawn_mp_context)
         self._last_queued_seq = SharedInt(self._spawn_mp_context, -1)
         self._last_ack_seq = SharedInt(self._spawn_mp_context, -1)
3 changes: 2 additions & 1 deletion src/neptune_scale/sync/errors_tracking.py

@@ -3,6 +3,7 @@
 __all__ = ("ErrorsQueue", "ErrorsMonitor")

 import multiprocessing
+import os
 import queue
 import time
 from collections.abc import Callable
@@ -44,7 +45,7 @@ def close(self) -> None:

 def default_error_callback(error: BaseException, last_seen_at: Optional[float]) -> None:
-    logger.error(error)
+    logger.error(f"[{os.environ.get('PYTEST_XDIST_TESTRUNUID')}] handling error {error}")


 def default_network_error_callback(error: BaseException, last_seen_at: Optional[float]) -> None:
10 changes: 6 additions & 4 deletions src/neptune_scale/sync/sync_process.py

@@ -193,7 +193,7 @@ def run_sync_process(
     last_ack_seq: SharedInt,
     last_ack_timestamp: SharedFloat,
 ) -> None:
-    logger.info("Data synchronization started")
+    logger.info(f"[{os.environ.get('PYTEST_XDIST_TESTRUNUID')}] Data synchronization started")
     stop_event = threading.Event()
     signal.signal(signal.SIGTERM, ft.partial(_handle_stop_signal, stop_event))

@@ -274,16 +274,18 @@ def close_all_threads() -> None:
                 break

             if not _is_process_running(parent_process):
-                logger.error("SyncProcess: parent process closed unexpectedly. Exiting")
+                logger.error(
+                    f"[{os.environ.get('PYTEST_XDIST_TESTRUNUID')}] SyncProcess: parent process closed unexpectedly. Exiting"
+                )
                 break

     except KeyboardInterrupt:
         logger.debug("KeyboardInterrupt received")
     finally:
-        logger.info("Data synchronization stopping")
+        logger.info(f"[{os.environ.get('PYTEST_XDIST_TESTRUNUID')}] Data synchronization stopping")
         close_all_threads()
         operations_repository.close(cleanup_files=False)
-        logger.info("Data synchronization finished")
+        logger.info(f"[{os.environ.get('PYTEST_XDIST_TESTRUNUID')}] Data synchronization finished")


 def _is_process_running(process: Optional[psutil.Process]) -> bool:
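The [{os.environ.get('PYTEST_XDIST_TESTRUNUID')}] prefix is now duplicated across four modules. If this debug logging is meant to stay, a small helper could centralize it; a hypothetical sketch (xdist_tag does not exist in the codebase):

import os

def xdist_tag() -> str:
    # pytest-xdist sets PYTEST_XDIST_TESTRUNUID for each test session; outside
    # of pytest this renders as "[None]", which is harmless in debug output.
    return f"[{os.environ.get('PYTEST_XDIST_TESTRUNUID')}]"

# usage:
# logger.info(f"{xdist_tag()} Data synchronization started")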
2 changes: 2 additions & 0 deletions tests/e2e/conftest.py

@@ -129,8 +129,10 @@ def client() -> AuthenticatedClient:


 def sleep_3s(**kwargs):
+    logging.info(f"[{os.environ.get('PYTEST_XDIST_TESTRUNUID')}] running mocked run_sync_process - sleep 3s")
     time.sleep(3)


 def sleep_10s(**kwargs):
+    logging.info(f"[{os.environ.get('PYTEST_XDIST_TESTRUNUID')}] running mocked run_sync_process - sleep 10s")
     time.sleep(10)
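For context on how sleep_3s and sleep_10s are consumed: they stand in for the real sync-process entry point so tests can simulate a hung or slow sync process. A sketch of the presumed wiring (the patch target path is an assumption, not shown in this diff):

from unittest.mock import patch

from tests.e2e.conftest import sleep_3s

# Replace the sync-process entry point with the 3-second sleeper; a Run started
# inside this block spawns a child that just sleeps, i.e. never syncs anything.
with patch("neptune_scale.sync.sync_process.run_sync_process", new=sleep_3s):
    ...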
41 changes: 40 additions & 1 deletion tests/e2e/test_log_and_fetch.py

@@ -1,3 +1,4 @@
+import logging
 import math
 import os
 import threading
@@ -6,24 +7,30 @@
     datetime,
     timezone,
 )
+from unittest.mock import patch

 import numpy as np
 from pytest import mark

 from neptune_scale.api.run import Run
+from neptune_scale.cli.sync import SyncRunner
+from neptune_scale.util import get_logger

 from .conftest import (
     random_series,
-    unique_path,
+    unique_path, sleep_3s,
 )
 from .test_fetcher import (
     fetch_attribute_values,
     fetch_metric_values,
 )
+from .test_sync import API_TOKEN

 NEPTUNE_PROJECT = os.getenv("NEPTUNE_E2E_PROJECT")
 SYNC_TIMEOUT = 30
+
+logger = get_logger()


 def test_atoms(run, client, project_name):
     """Set atoms to a value, make sure it's equal when fetched"""
@@ -158,3 +165,35 @@ def test_async_lag_callback():
     # Second callback should be called after logging configs
     event.wait(timeout=60)
     assert event.is_set()
+
+
+def test_concurrent(client, project_name, run_init_kwargs):
+    logger.info("Phase 1")  ###
+
+    def in_thread():
+        with Run(mode="offline") as run_1:
+            db_path = run_1._operations_repo._db_path
+            for i in range(10_000):
+                run_1.log_configs(data={f"int-value-{i}": i})
+
+        runner = SyncRunner(api_token=API_TOKEN, run_log_file=db_path)
+        runner.start()
+        time.sleep(2)
+
+    thread = threading.Thread(target=in_thread)
+    thread.start()
+    thread.join(timeout=SYNC_TIMEOUT)
+
+    logger.info("Phase 2")  ###
+
+    run_2 = Run(resume=True)
+    run_2._sync_process.terminate()
+    run_2._sync_process.join()
+
+    for i in range(5):
+        run_2.log_configs({f"test_concurrent/int-value-{i}": i * 2 + 1})
+
+    time.sleep(5)
+
+    logger.info("Phase 3")  ###
+    assert not run_2.wait_for_processing(SYNC_TIMEOUT)
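What the final assertion exercises, reduced to a self-contained sketch: after terminate(), the waiter must observe the dead worker and give up with False instead of blocking for the full timeout. Standard-library only; the real code path goes through Run._wait() and _is_process_running() above:

import multiprocessing
import time

def worker() -> None:
    time.sleep(60)  # stands in for a sync process that never finishes its work

if __name__ == "__main__":
    ctx = multiprocessing.get_context("spawn")
    proc = ctx.Process(target=worker)
    proc.start()
    proc.terminate()
    proc.join()

    # Analogous to wait_for_processing(): poll until done or the worker is gone.
    done = False
    deadline = time.monotonic() + 30
    while time.monotonic() < deadline:
        if not proc.is_alive():
            break  # waiting interrupted: the worker died before finishing
        time.sleep(0.1)
    assert not done  # mirrors `assert not run_2.wait_for_processing(SYNC_TIMEOUT)`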
2 changes: 2 additions & 0 deletions tests/e2e/test_sync.py

@@ -220,6 +220,8 @@ def test_sync_stop_timeout(run_init_kwargs, timeout, hung_method):
     db_path = run._operations_repo._db_path
     run.log_configs(data={"str-value": "hello-world"})
     run.assign_files(files={"a-file": b"content"})
+    for i in range(15):  # 17 logs in total
+        run.log_configs(data={f"int-value-{i}": i})

     runner = SyncRunner(api_token=API_TOKEN, run_log_file=db_path)
     runner.start()
7 changes: 6 additions & 1 deletion tests/e2e/test_sync_process_failures.py

@@ -14,10 +14,14 @@
     OperationType,
 )
 from neptune_scale.sync.parameters import MAX_SINGLE_OPERATION_SIZE_BYTES
-from neptune_scale.util import envs
+from neptune_scale.util import (
+    envs,
+    get_logger,
+)
+from tests.e2e.conftest import sleep_3s

 NEPTUNE_PROJECT = os.getenv("NEPTUNE_E2E_PROJECT")
+logger = get_logger()

 # Timeout value for all the tests.
 #
@@ -91,6 +95,7 @@ def test_run_wait_methods_after_sync_process_dies_during_wait(
     wait_for_submission, wait_for_processing, wait_for_file_upload
 ):
     """Make sure we're not blocked forever if the sync process dies before completing all the work."""
+    logger.error(f"[{os.environ.get('PYTEST_XDIST_TESTRUNUID')}] starting test")

     run = Run()
     run.log_metrics({"metric": 2}, step=1)