
Commit 3418861

zhongchun and 不涸 authored

[Learn] Fix lightgbm machines (#3351)

* add worker ip meta for chunk meta
* fix expect worker
* fix CollectPorts
* fix test_collect_ports
* Optimize lightgbm train
* Fix meta setting
* Add get_node_ip_address
* Fix test_local_ranker
* Ignore ip when no lightgbm
* Lint
* add lightgbm tests running on ray

Co-authored-by: 不涸 <zhongchun.yzc@antgroup.com>

1 parent fa207a6 · commit 3418861

7 files changed: +97 −15 lines changed

.github/workflows/platform-ci.yml

Lines changed: 1 addition & 0 deletions
@@ -108,6 +108,7 @@ jobs:
           pip install "xgboost_ray<0.1.14" "protobuf<4"
           # Ray Datasets need pyarrow>=6.0.1
           pip install "pyarrow>=6.0.1"
+          pip install lightgbm
         fi
         if [ -n "$RUN_DASK" ]; then
           pip install "dask[complete]" "mimesis<9.0.0" scikit-learn

mars/deploy/oscar/tests/test_ray.py

Lines changed: 6 additions & 0 deletions
@@ -23,6 +23,7 @@
 
 from .... import tensor as mt
 from .... import dataframe as md
+from ....learn.contrib.lightgbm.tests import test_classifier
 from ....oscar.errors import ReconstructWorkerError
 from ....session import get_default_session, new_session
 from ....tests.core import require_ray, mock, DICT_NOT_EMPTY
@@ -335,3 +336,8 @@ def test_init_metrics_on_ray(ray_start_regular_shared, create_cluster):
     assert api._metric_backend == "ray"
 
     client.session.stop_server()
+
+
+@require_ray
+def test_lightgbm_classifier_on_ray(ray_start_regular_shared, create_cluster):
+    test_classifier.test_local_classifier(create_cluster)

mars/learn/contrib/lightgbm/_train.py

Lines changed: 25 additions & 13 deletions
@@ -211,8 +211,19 @@ def __call__(self):
     @staticmethod
     def _get_data_chunks_workers(ctx, data):
         # data_chunk.inputs is concat, and concat's input is the co-allocated chunks
-        metas = ctx.get_chunks_meta([c.key for c in data.chunks], fields=["bands"])
-        return [m["bands"][0][0] for m in metas]
+        metas = ctx.get_chunks_meta(
+            [c.key for c in data.chunks], fields=["ip", "bands"]
+        )
+
+        ips = []
+        ip_to_worker = {}
+        for m in metas:
+            ip = m["ip"]
+            assert ip, f"There is a meta {m} that doesn't contain an ip."
+            ips.append(ip)
+            bands = m["bands"]
+            ip_to_worker[ip] = bands[0][0] if bands else None
+        return ips, ip_to_worker
 
     @staticmethod
     def _concat_chunks_by_worker(chunks, chunk_workers):
@@ -230,23 +241,24 @@ def tile(cls, op: "LGBMTrain"):
         data = op.data
         worker_to_args = defaultdict(dict)
 
-        workers = cls._get_data_chunks_workers(ctx, data)
+        # Note: a Mars worker is a band address, while an LGBMTrain worker is a machine ip.
+        ips, ip_to_worker = cls._get_data_chunks_workers(ctx, data)
 
         for arg in ["_data", "_label", "_sample_weight", "_init_score"]:
            if getattr(op, arg) is not None:
                 for worker, chunk in cls._concat_chunks_by_worker(
-                    getattr(op, arg).chunks, workers
+                    getattr(op, arg).chunks, ips
                 ).items():
                     worker_to_args[worker][arg] = chunk
 
         if op.eval_datas:
             eval_workers_list = [
-                cls._get_data_chunks_workers(ctx, d) for d in op.eval_datas
+                cls._get_data_chunks_workers(ctx, d)[0] for d in op.eval_datas
             ]
             extra_workers = reduce(
                 operator.or_, (set(w) for w in eval_workers_list)
-            ) - set(workers)
-            worker_remap = dict(zip(extra_workers, itertools.cycle(workers)))
+            ) - set(ips)
+            worker_remap = dict(zip(extra_workers, itertools.cycle(ips)))
             if worker_remap:
                 eval_workers_list = [
                     [worker_remap.get(w, w) for w in wl] for wl in eval_workers_list
@@ -270,10 +282,11 @@ def tile(cls, op: "LGBMTrain"):
                     worker_to_args[worker][arg].append(chunk)
 
         out_chunks = []
-        workers = list(set(workers))
-        for worker_id, worker in enumerate(workers):
+        ips = list(set(ips))
+        workers = list(ip_to_worker.values())
+        for worker_id, worker in enumerate(ips):
             chunk_op = op.copy().reset_key()
-            chunk_op.expect_worker = worker
+            chunk_op.expect_worker = ip_to_worker[worker]
 
             input_chunks = []
             concat_args = worker_to_args.get(worker, {})
@@ -301,7 +314,7 @@ def tile(cls, op: "LGBMTrain"):
             ).chunks[0]
             input_chunks.append(worker_ports_chunk)
 
-            chunk_op._workers = workers
+            chunk_op._workers = ips
             chunk_op._worker_ports = worker_ports_chunk
             chunk_op._worker_id = worker_id
 
@@ -357,9 +370,8 @@ def execute(cls, ctx, op: "LGBMTrain"):
         # if model is trained, remove unsupported parameters
         params.pop("out_dtype_", None)
         worker_ports = ctx[op.worker_ports.key]
-        worker_ips = [worker.split(":", 1)[0] for worker in op.workers]
         worker_endpoints = [
-            f"{worker}:{port}" for worker, port in zip(worker_ips, worker_ports)
+            f"{worker}:{port}" for worker, port in zip(op.workers, worker_ports)
         ]
 
         params["machines"] = ",".join(worker_endpoints)

mars/learn/contrib/lightgbm/tests/test_classifier.py

Lines changed: 16 additions & 1 deletion
@@ -21,6 +21,7 @@
 
 from ..... import tensor as mt
 from ..... import dataframe as md
+from .....deploy.oscar.local import new_cluster
 
 try:
     import lightgbm
@@ -45,8 +46,22 @@
 X_sparse = mt.tensor(x_sparse, chunk_size=chunk_size).tosparse(missing=np.nan)[filter]
 
 
+@pytest.mark.parametrize(indirect=True)
+@pytest.fixture
+async def create_cluster():
+    start_method = os.environ.get("POOL_START_METHOD", None)
+    client = await new_cluster(
+        subprocess_start_method=start_method,
+        n_worker=2,
+        n_cpu=4,
+        use_uvloop=False,
+    )
+    async with client:
+        yield client
+
+
 @pytest.mark.skipif(lightgbm is None, reason="LightGBM not installed")
-def test_local_classifier(setup):
+def test_local_classifier(create_cluster):
     y_data = (y * 10).astype(mt.int32)
     classifier = LGBMClassifier(n_estimators=2)
     classifier.fit(X, y_data, eval_set=[(X, y_data)], verbose=True)

mars/services/meta/core.py

Lines changed: 1 addition & 0 deletions
@@ -68,6 +68,7 @@ class _ChunkMeta(_CommonMeta):
     bands: List[BandType] = None
     # needed by ray ownership to keep object alive when worker died.
     object_refs: List[Any] = None
+    ip: str = None
 
     def merge_from(self, value: "_ChunkMeta"):
         if value.bands:
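
To make the two address kinds concrete: a band identifies a worker process, while the new `ip` field identifies the machine it runs on. A trimmed-down, hypothetical sketch of the record (the real `_ChunkMeta` derives from `_CommonMeta` and carries more fields):

from dataclasses import dataclass
from typing import Any, List, Optional, Tuple

# A Mars band is (worker process address, band name), e.g. ("127.0.0.1:17346", "numa-0").
BandType = Tuple[str, str]

@dataclass
class ChunkMetaSketch:
    bands: Optional[List[BandType]] = None   # process-level location(s) of the chunk
    object_refs: Optional[List[Any]] = None  # Ray ownership anchors for the chunk's objects
    ip: Optional[str] = None                 # bare machine address, consumed by LightGBM training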

mars/services/subtask/worker/processor.py

Lines changed: 13 additions & 1 deletion
@@ -32,7 +32,12 @@
 from ....optimization.physical import optimize
 from ....serialization import AioSerializer
 from ....typing import BandType, ChunkType
-from ....utils import get_chunk_key_to_data_keys, calc_data_size
+from ....utils import (
+    calc_data_size,
+    get_chunk_key_to_data_keys,
+    get_node_ip_address,
+    lazy_import,
+)
 from ...context import ThreadedServiceContext
 from ...meta.api import MetaAPI, WorkerMetaAPI
 from ...session import SessionAPI
@@ -41,6 +46,9 @@
 from ..core import Subtask, SubtaskStatus, SubtaskResult
 from ..utils import iter_input_data_keys, iter_output_data, get_mapper_data_keys
 
+
+lightgbm = lazy_import("lightgbm")
+
 logger = logging.getLogger(__name__)
 
 
@@ -464,6 +472,10 @@ async def _store_meta(
                 bands=[self._band],
                 chunk_key=chunk_key,
                 exclude_fields=["object_ref"],
+                # Note: why add the `ip` field? LightGBM needs the machine
+                # addresses where the data reside to implement distributed
+                # training.
+                ip=get_node_ip_address() if lightgbm else None,
             )
         )
         # for supervisor, only save basic meta that is small like memory_size etc
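
The `if lightgbm` gate implies that `lazy_import` evaluates to something falsy when the module is not installed, so clusters without LightGBM skip the extra IP lookup entirely. A simplified, eagerly-importing stand-in for that pattern (the real `lazy_import` defers the actual import; `lazy_import_sketch` is an illustrative name):

import importlib
import importlib.util

def lazy_import_sketch(name: str):
    # Return the module if it is installed, else None (falsy),
    # mirroring the gate used in _store_meta above.
    if importlib.util.find_spec(name) is None:
        return None
    return importlib.import_module(name)

lightgbm = lazy_import_sketch("lightgbm")
print("will store ip in chunk meta:", lightgbm is not None)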

mars/utils.py

Lines changed: 35 additions & 0 deletions
@@ -1877,3 +1877,38 @@ def retry_call(*args, **kwargs):
         raise ex  # pylint: disable-msg=E0702
 
     return retry_call
+
+
+# `get_node_ip_address` is taken from Ray.
+# https://github.com/ray-project/ray/blob/master/python/ray/_private/services.py#L617
+def get_node_ip_address(address="8.8.8.8:53"):
+    """Determine the IP address of the local node.
+
+    Args:
+        address (str): The IP address and port of any known live service on
+            the network you care about.
+
+    Returns:
+        The IP address of the current node.
+    """
+    ip_address, port = address.split(":")
+    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+    try:
+        # This command will raise an exception if there is no internet
+        # connection.
+        s.connect((ip_address, int(port)))
+        node_ip_address = s.getsockname()[0]
+    except Exception as e:  # pragma: no cover
+        node_ip_address = "127.0.0.1"
+        # [Errno 101] Network is unreachable
+        if e.errno == 101:
+            try:
+                # try to get the node ip address from the host name
+                host_name = socket.getfqdn(socket.gethostname())
+                node_ip_address = socket.gethostbyname(host_name)
+            except Exception:  # noqa: E722 # nosec # pylint: disable=bare-except
+                pass
+    finally:
+        s.close()
+
+    return node_ip_address
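
Note that connecting a UDP socket sends no packets; it only asks the kernel which local interface would route toward the target, so the default `8.8.8.8:53` probe yields the node's outward-facing address even though nothing is actually contacted. A brief usage sketch (returned addresses are illustrative; assumes Mars with this commit is importable):

from mars.utils import get_node_ip_address

print(get_node_ip_address())                 # e.g. "192.168.1.15"
print(get_node_ip_address("10.0.0.1:6379"))  # route toward an in-cluster service instead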
