Skip to content

Commit a464573

Browse files
authored
Stop importing main module when starting Mars local cluster (#3110)
1 parent 24290fe commit a464573

File tree

6 files changed

+119
-14
lines changed

6 files changed

+119
-14
lines changed

.github/workflows/benchmark-ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,5 +62,5 @@ jobs:
6262
uses: actions/upload-artifact@v2
6363
with:
6464
name: Benchmarks log
65-
path: benchmarks/asv_bench/benchmarks.log
65+
path: benchmarks/asv_bench/results
6666
if: failure()

benchmarks/asv_bench/benchmarks/storage.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import itertools
1616

17+
import cloudpickle
1718
import numpy as np
1819
import pandas as pd
1920

@@ -89,8 +90,20 @@ class TransferPackageSuite:
8990
"""
9091

9192
def setup(self):
    """Prepare the benchmark: opt into by-value pickling, then start Mars.

    The registration is best effort; failures to import the package or to
    find the cloudpickle hook are ignored and the session is started
    regardless.
    """
    try:
        # Look the hook up first so a missing attribute is detected before
        # any import is attempted.
        register_by_value = cloudpickle.register_pickle_by_value
        # ``__import__`` with a dotted name returns the top-level package,
        # so functions in every submodule are serialized by value instead
        # of by reference.
        register_by_value(__import__("benchmarks.storage"))
    except (AttributeError, ImportError):
        # NOTE(review): presumably covers cloudpickle releases without
        # ``register_pickle_by_value`` and runs where the ``benchmarks``
        # package is not importable -- confirm.
        pass
    mars.new_session(n_worker=2, n_cpu=8)
9399

100+
def teardown(self):
    """Stop the Mars server and undo the by-value pickling registration.

    Unregistering is best effort and mirrors ``setup``: a missing
    cloudpickle hook or an unimportable package is silently ignored.
    """
    mars.stop_server()
    try:
        # Resolve the hook before importing, keeping the same evaluation
        # order as a direct call expression.
        unregister_by_value = cloudpickle.unregister_pickle_by_value
        unregister_by_value(__import__("benchmarks.storage"))
    except (AttributeError, ImportError):
        pass
106+
94107
def time_1_to_1(self):
95108
return mr.spawn(send_1_to_1).execute().fetch()
96109

@@ -99,3 +112,4 @@ def time_1_to_1(self):
99112
suite = TransferPackageSuite()
100113
suite.setup()
101114
print(suite.time_1_to_1())
115+
suite.teardown()

mars/deploy/oscar/tests/test_local.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,17 @@
1515
import asyncio
1616
import copy
1717
import os
18+
import subprocess
19+
import sys
1820
import threading
1921
import tempfile
22+
import textwrap
2023
import time
2124
import uuid
2225

2326
import numpy as np
2427
import pandas as pd
28+
import psutil
2529
import pytest
2630

2731
try:
@@ -917,3 +921,49 @@ def time_consuming(start, x):
917921
.fetch()
918922
== pd.Series(list(range(series_size))).apply(lambda x: x * x).sum()
919923
)
924+
925+
926+
def test_naive_code_file():
    """Run a plain script that starts a Mars session in a fresh interpreter.

    The script has no ``if __name__ == "__main__"`` guard. It writes the sum
    of a 10x10 tensor of ones to the file named by ``RESULTPATH``; the test
    then checks that the value is 100. If the script does not finish within
    120 seconds, the whole process tree is killed and the timeout is
    re-raised.
    """
    code_file = """
    import mars
    import mars.tensor as mt
    import os

    mars.new_session()
    try:
        result_path = os.environ["RESULTPATH"]
        with open(result_path, "w") as outf:
            outf.write(str(mt.ones((10, 10)).sum().execute()))
    finally:
        mars.stop_server()
    """

    with tempfile.TemporaryDirectory() as temp_dir:
        script_path = os.path.join(temp_dir, "test_file.py")
        result_path = os.path.join(temp_dir, "result.txt")

        with open(script_path, "w") as file_obj:
            file_obj.write(textwrap.dedent(code_file))

        # Hand the child the same import path as this interpreter and tell
        # it where to write its result.
        env = dict(
            os.environ,
            PYTHONPATH=os.path.pathsep.join(sys.path),
            RESULTPATH=result_path,
        )

        child = subprocess.Popen([sys.executable, script_path], env=env)
        try:
            child.wait(timeout=120)

            with open(result_path, "r") as inp_file:
                assert int(float(inp_file.read())) == 100
        except subprocess.TimeoutExpired:
            # Collect the script process and all of its descendants; if the
            # root is already gone there is nothing left to clean up.
            try:
                root = psutil.Process(child.pid)
                family = [root, *root.children(recursive=True)]
            except psutil.NoSuchProcess:
                family = []

            # Kill in reverse order so the root process goes last.
            for member in reversed(family):
                try:
                    member.kill()
                except psutil.NoSuchProcess:
                    pass
            raise

mars/oscar/backends/mars/pool.py

Lines changed: 52 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,14 @@
1414

1515
import asyncio
1616
import concurrent.futures as futures
17+
import contextlib
1718
import logging.config
1819
import multiprocessing
1920
import os
2021
import random
2122
import signal
2223
import sys
24+
import threading
2325
import uuid
2426
from dataclasses import dataclass
2527
from types import TracebackType
@@ -63,6 +65,37 @@ def _mp_kill(self):
6365
BaseProcess.kill = _mp_kill
6466

6567
logger = logging.getLogger(__name__)
68+
_init_main_suspended_local = threading.local()
69+
70+
71+
def _patch_spawn_get_preparation_data():
    """Install a wrapper around ``multiprocessing.spawn.get_preparation_data``.

    The wrapper returns the original preparation data unchanged unless the
    thread-local ``_init_main_suspended_local.value`` flag is truthy, in
    which case the ``init_main_from_name`` and ``init_main_from_path``
    entries are dropped so the user's main module is not imported when the
    Mars cluster is started.

    The patch is applied at most once: a function already tagged with
    ``_mars_patched`` is left in place. If the spawn module or the function
    is unavailable, nothing happens.
    """
    try:
        from multiprocessing import spawn as mp_spawn

        _raw_get_preparation_data = mp_spawn.get_preparation_data
    except (ImportError, AttributeError):  # pragma: no cover
        return

    if getattr(_raw_get_preparation_data, "_mars_patched", False):
        # already patched by an earlier call, keep the installed wrapper
        return

    def _patched_get_preparation_data(*args, **kw):
        prep_data = _raw_get_preparation_data(*args, **kw)
        if getattr(_init_main_suspended_local, "value", False):
            # make sure the user module is not imported when starting
            # a Mars cluster
            for main_key in ("init_main_from_name", "init_main_from_path"):
                prep_data.pop(main_key, None)
        return prep_data

    _patched_get_preparation_data._mars_patched = True
    mp_spawn.get_preparation_data = _patched_get_preparation_data
90+
91+
92+
@contextlib.contextmanager
def _suspend_init_main():
    """Mark the current thread as starting a process without main import.

    While the context is active, ``_init_main_suspended_local.value`` is
    ``True`` for the calling thread. On exit the value that was in effect
    before entering is restored (``False`` if the flag had never been set),
    so nested uses do not clear the flag while an outer context is still
    active.
    """
    # Remember the prior state instead of assuming it was unset; blindly
    # resetting to False would break re-entrant use on the same thread.
    previous = getattr(_init_main_suspended_local, "value", False)
    _init_main_suspended_local.value = True
    try:
        yield
    finally:
        _init_main_suspended_local.value = previous
6699

67100

68101
@dataslots
@@ -131,21 +164,25 @@ async def start_sub_pool(
131164
def start_pool_in_process():
132165
ctx = multiprocessing.get_context(method=start_method)
133166
status_queue = ctx.Queue()
134-
process = ctx.Process(
135-
target=cls._start_sub_pool,
136-
args=(actor_pool_config, process_index, status_queue),
137-
name=f"MarsActorPool{process_index}",
138-
)
139-
process.daemon = True
140-
process.start()
167+
168+
with _suspend_init_main():
169+
process = ctx.Process(
170+
target=cls._start_sub_pool,
171+
args=(actor_pool_config, process_index, status_queue),
172+
name=f"MarsActorPool{process_index}",
173+
)
174+
process.daemon = True
175+
process.start()
176+
141177
# wait for sub actor pool to finish starting
142178
process_status = status_queue.get()
143179
return process, process_status
144180

181+
_patch_spawn_get_preparation_data()
145182
loop = asyncio.get_running_loop()
146-
executor = futures.ThreadPoolExecutor(1)
147-
create_pool_task = loop.run_in_executor(executor, start_pool_in_process)
148-
return await create_pool_task
183+
with futures.ThreadPoolExecutor(1) as executor:
184+
create_pool_task = loop.run_in_executor(executor, start_pool_in_process)
185+
return await create_pool_task
149186

150187
@classmethod
151188
async def wait_sub_pools_ready(cls, create_pool_tasks: List[asyncio.Task]):
@@ -240,8 +277,11 @@ async def kill_sub_pool(
240277
pass
241278
process.terminate()
242279
wait_pool = futures.ThreadPoolExecutor(1)
243-
loop = asyncio.get_running_loop()
244-
await loop.run_in_executor(wait_pool, process.join, 3)
280+
try:
281+
loop = asyncio.get_running_loop()
282+
await loop.run_in_executor(wait_pool, process.join, 3)
283+
finally:
284+
wait_pool.shutdown(False)
245285
process.kill()
246286
await asyncio.to_thread(process.join, 5)
247287

mars/oscar/debug.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ async def _log_timeout(timeout, msg, *args, **kwargs):
7979
await asyncio.sleep(timeout * rnd)
8080
rnd += 1
8181
logger.warning(
82-
msg + "(timeout for %.4f seconds).",
82+
msg + " (timeout for %.4f seconds).",
8383
*args,
8484
time.time() - start_time,
8585
**kwargs,

mars/tests/test_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,7 @@ def write(self, d):
407407
print("LINE 2", file=sys.stderr, end="\n")
408408
finally:
409409
sys.stdout, sys.stderr = old_stdout, old_stderr
410+
executor.shutdown(False)
410411

411412
assert stdout_w.content == "LINE T\nLINE 1\n"
412413
assert stderr_w.content == "LINE 2\n"

0 commit comments

Comments
 (0)